aboutsummaryrefslogtreecommitdiff
path: root/db
diff options
context:
space:
mode:
authorjloup <jean-loup.jamet@trainline.com>2018-02-14 14:19:09 +0100
committerjloup <jean-loup.jamet@trainline.com>2018-02-14 14:19:09 +0100
commit7a9e5112eaaea58d55f181d3e5296e4ff839921c (patch)
tree968ed193f42a1fad759cc89ad2f8ad5b0091291e /db
downloadFront-7a9e5112eaaea58d55f181d3e5296e4ff839921c.tar.gz
Front-7a9e5112eaaea58d55f181d3e5296e4ff839921c.tar.zst
Front-7a9e5112eaaea58d55f181d3e5296e4ff839921c.zip
initial commit
Diffstat (limited to 'db')
-rw-r--r--db/db.go75
-rw-r--r--db/db_test.go35
-rw-r--r--db/errors.go23
-rw-r--r--db/market_config.go45
-rw-r--r--db/user.go72
5 files changed, 250 insertions, 0 deletions
diff --git a/db/db.go b/db/db.go
new file mode 100644
index 0000000..bc4b8b3
--- /dev/null
+++ b/db/db.go
@@ -0,0 +1,75 @@
1package db
2
3import (
4 "fmt"
5 "strings"
6
7 "github.com/go-pg/pg"
8 "github.com/go-pg/pg/orm"
9 "github.com/jloup/utils"
10)
11
12var DB *pg.DB
13
14var log = utils.StandardL().WithField("module", "db")
15
16type DBConfig struct {
17 Address string
18 Database string
19 User string
20 Password string
21}
22
23func Init(config DBConfig) {
24 var err error
25
26 DB = connect(config)
27
28 err = createSchema(DB)
29 if err != nil {
30 log.Errorf("cannot create schemas %v\n", err)
31 }
32
33 err = createIndexes(DB)
34 if err != nil {
35 log.Errorf("cannot create indexes %v\n", err)
36 }
37}
38
39func connect(config DBConfig) *pg.DB {
40 return pg.Connect(&pg.Options{
41 User: config.User,
42 Password: config.Password,
43 Database: config.Database,
44 Addr: config.Address,
45 })
46}
47
48func createSchema(db *pg.DB) error {
49 for _, model := range []interface{}{&User{}, &MarketConfig{}} {
50 err := db.CreateTable(model, &orm.CreateTableOptions{IfNotExists: true})
51 if err != nil {
52 return err
53 }
54 }
55 return nil
56}
57
58func createIndexes(db *pg.DB) error {
59 indexes := []struct {
60 TableName string
61 Name string
62 Columns []string
63 }{
64 {"market_configs", "market_name_user_id_idx", []string{"user_id", "market_name"}},
65 }
66
67 for _, index := range indexes {
68 _, err := db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s)", index.Name, index.TableName, strings.Join(index.Columns, ",")))
69 if err != nil {
70 return err
71 }
72 }
73
74 return nil
75}
diff --git a/db/db_test.go b/db/db_test.go
new file mode 100644
index 0000000..0481915
--- /dev/null
+++ b/db/db_test.go
@@ -0,0 +1,35 @@
1package db
2
3import "testing"
4
5func TestInit(t *testing.T) {
6 Init(DBConfig{"localhost:5432", "cryptoportfolio", "cryptoportfolio", "cryptoportfolio-dev"})
7}
8
9func TestUpdateUser(t *testing.T) {
10 Init(DBConfig{"localhost:5432", "cryptoportfolio", "cryptoportfolio", "cryptoportfolio-dev"})
11 t.Log(InsertUser(&User{Email: "j@test.com", PasswordHash: "yp"}))
12 err := InsertUser(&User{Email: "t2@test.com", PasswordHash: "yp"})
13
14 t.Log(err, IsDup(err))
15
16 t.Log(GetUserByEmail("testyo"))
17}
18
19func TestMarketConfig(t *testing.T) {
20 Init(DBConfig{"localhost:5432", "cryptoportfolio", "cryptoportfolio", "cryptoportfolio-dev"})
21
22 config := MarketConfig{UserId: 1, MarketName: "poloniex"}
23 config.Config = make(map[string]string)
24
25 config.Config["secret"] = "key"
26
27 t.Log(InsertMarketConfig(&config))
28 t.Log(config)
29
30 t.Log(GetUserMarketConfig(1, "poloniex"))
31
32 config.Config["secret2"] = "key2"
33 t.Log(SetUserMarketConfig(1, "poloniex", config.Config))
34 t.Log(SetUserMarketConfig(1, "bifinance", config.Config))
35}
diff --git a/db/errors.go b/db/errors.go
new file mode 100644
index 0000000..ed5f371
--- /dev/null
+++ b/db/errors.go
@@ -0,0 +1,23 @@
1package db
2
3import (
4 "strings"
5
6 "github.com/go-pg/pg"
7)
8
9func PGCode(err error) string {
10 if _, ok := err.(pg.Error); !ok {
11 return ""
12 }
13
14 return err.(pg.Error).Field('C')
15}
16
17func IsDup(err error) bool {
18 return PGCode(err) == "23505"
19}
20
21func IsSQLError(err error) bool {
22 return strings.HasPrefix(err.Error(), "ERROR #")
23}
diff --git a/db/market_config.go b/db/market_config.go
new file mode 100644
index 0000000..b26c092
--- /dev/null
+++ b/db/market_config.go
@@ -0,0 +1,45 @@
1package db
2
3import "github.com/go-pg/pg"
4
5type MarketConfig struct {
6 Id int64
7 MarketName string `sql:",notnull"`
8 UserId int64 `sql:",notnull"`
9 Config map[string]string
10}
11
12func InsertMarketConfig(config *MarketConfig) error {
13 return DB.Insert(config)
14}
15
16func GetUserMarketConfig(userId int64, market string) (*MarketConfig, error) {
17 var config MarketConfig
18
19 err := DB.Model(&config).Where("user_id = ?", userId).Where("market_name = ?", market).First()
20
21 if err != nil && err != pg.ErrNoRows {
22 return nil, err
23 }
24
25 if err == pg.ErrNoRows {
26 return nil, nil
27 } else {
28 return &config, nil
29 }
30}
31
32func SetUserMarketConfig(userId int64, market string, newConfig map[string]string) (*MarketConfig, error) {
33 config := MarketConfig{
34 UserId: userId,
35 MarketName: market,
36 Config: newConfig,
37 }
38
39 _, err := DB.Model(&config).
40 OnConflict("(user_id, market_name) DO UPDATE").
41 Set("config = ?", newConfig).
42 Insert()
43
44 return &config, err
45}
diff --git a/db/user.go b/db/user.go
new file mode 100644
index 0000000..aed0ac1
--- /dev/null
+++ b/db/user.go
@@ -0,0 +1,72 @@
1package db
2
3import (
4 "golang.org/x/crypto/bcrypt"
5)
6
7type UserStatus uint8
8
9const (
10 Confirmed UserStatus = iota + 1
11 AwaitingConfirmation
12)
13
14type User struct {
15 Id int64
16 Email string `sql:",unique,notnull"`
17 PasswordHash string `sql:",notnull"`
18 OtpSecret string
19 IsOtpSetup bool
20 Status UserStatus
21}
22
23func HashPassword(password string) (string, error) {
24 b, err := bcrypt.GenerateFromPassword([]byte(password), 10)
25
26 return string(b), err
27}
28
29func ValidatePassword(password string, hash string) error {
30 return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
31}
32
33func InsertUser(user *User) error {
34 return DB.Insert(user)
35}
36
37func ConfirmUserByEmail(email string) error {
38 _, err := DB.Model(&User{}).Set("status=?", Confirmed).Where("email=?", email).Returning("*").Update()
39
40 return err
41}
42
43func GetUserById(id int64) (*User, error) {
44 user := User{Id: id}
45
46 err := DB.Select(&user)
47
48 return &user, err
49}
50
51func GetUserByEmail(email string) (*User, error) {
52 var users []User
53
54 err := DB.Model(&users).Where("email = ?", email).Select()
55
56 if err != nil {
57 return nil, err
58 }
59
60 if len(users) == 0 {
61 return nil, nil
62 }
63
64 return &users[0], nil
65}
66
67func SetOtpSecret(user *User, secret string, temporary bool) error {
68 user.OtpSecret = secret
69 user.IsOtpSetup = !temporary
70
71 return DB.Update(user)
72}