diff options
Diffstat (limited to 'db')
-rw-r--r-- | db/db.go | 75 | ||||
-rw-r--r-- | db/db_test.go | 35 | ||||
-rw-r--r-- | db/errors.go | 23 | ||||
-rw-r--r-- | db/market_config.go | 45 | ||||
-rw-r--r-- | db/user.go | 72 |
5 files changed, 250 insertions, 0 deletions
diff --git a/db/db.go b/db/db.go new file mode 100644 index 0000000..bc4b8b3 --- /dev/null +++ b/db/db.go | |||
@@ -0,0 +1,75 @@ | |||
1 | package db | ||
2 | |||
3 | import ( | ||
4 | "fmt" | ||
5 | "strings" | ||
6 | |||
7 | "github.com/go-pg/pg" | ||
8 | "github.com/go-pg/pg/orm" | ||
9 | "github.com/jloup/utils" | ||
10 | ) | ||
11 | |||
12 | var DB *pg.DB | ||
13 | |||
14 | var log = utils.StandardL().WithField("module", "db") | ||
15 | |||
16 | type DBConfig struct { | ||
17 | Address string | ||
18 | Database string | ||
19 | User string | ||
20 | Password string | ||
21 | } | ||
22 | |||
23 | func Init(config DBConfig) { | ||
24 | var err error | ||
25 | |||
26 | DB = connect(config) | ||
27 | |||
28 | err = createSchema(DB) | ||
29 | if err != nil { | ||
30 | log.Errorf("cannot create schemas %v\n", err) | ||
31 | } | ||
32 | |||
33 | err = createIndexes(DB) | ||
34 | if err != nil { | ||
35 | log.Errorf("cannot create indexes %v\n", err) | ||
36 | } | ||
37 | } | ||
38 | |||
39 | func connect(config DBConfig) *pg.DB { | ||
40 | return pg.Connect(&pg.Options{ | ||
41 | User: config.User, | ||
42 | Password: config.Password, | ||
43 | Database: config.Database, | ||
44 | Addr: config.Address, | ||
45 | }) | ||
46 | } | ||
47 | |||
48 | func createSchema(db *pg.DB) error { | ||
49 | for _, model := range []interface{}{&User{}, &MarketConfig{}} { | ||
50 | err := db.CreateTable(model, &orm.CreateTableOptions{IfNotExists: true}) | ||
51 | if err != nil { | ||
52 | return err | ||
53 | } | ||
54 | } | ||
55 | return nil | ||
56 | } | ||
57 | |||
58 | func createIndexes(db *pg.DB) error { | ||
59 | indexes := []struct { | ||
60 | TableName string | ||
61 | Name string | ||
62 | Columns []string | ||
63 | }{ | ||
64 | {"market_configs", "market_name_user_id_idx", []string{"user_id", "market_name"}}, | ||
65 | } | ||
66 | |||
67 | for _, index := range indexes { | ||
68 | _, err := db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s)", index.Name, index.TableName, strings.Join(index.Columns, ","))) | ||
69 | if err != nil { | ||
70 | return err | ||
71 | } | ||
72 | } | ||
73 | |||
74 | return nil | ||
75 | } | ||
diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 0000000..0481915 --- /dev/null +++ b/db/db_test.go | |||
@@ -0,0 +1,35 @@ | |||
1 | package db | ||
2 | |||
3 | import "testing" | ||
4 | |||
5 | func TestInit(t *testing.T) { | ||
6 | Init(DBConfig{"localhost:5432", "cryptoportfolio", "cryptoportfolio", "cryptoportfolio-dev"}) | ||
7 | } | ||
8 | |||
9 | func TestUpdateUser(t *testing.T) { | ||
10 | Init(DBConfig{"localhost:5432", "cryptoportfolio", "cryptoportfolio", "cryptoportfolio-dev"}) | ||
11 | t.Log(InsertUser(&User{Email: "j@test.com", PasswordHash: "yp"})) | ||
12 | err := InsertUser(&User{Email: "t2@test.com", PasswordHash: "yp"}) | ||
13 | |||
14 | t.Log(err, IsDup(err)) | ||
15 | |||
16 | t.Log(GetUserByEmail("testyo")) | ||
17 | } | ||
18 | |||
19 | func TestMarketConfig(t *testing.T) { | ||
20 | Init(DBConfig{"localhost:5432", "cryptoportfolio", "cryptoportfolio", "cryptoportfolio-dev"}) | ||
21 | |||
22 | config := MarketConfig{UserId: 1, MarketName: "poloniex"} | ||
23 | config.Config = make(map[string]string) | ||
24 | |||
25 | config.Config["secret"] = "key" | ||
26 | |||
27 | t.Log(InsertMarketConfig(&config)) | ||
28 | t.Log(config) | ||
29 | |||
30 | t.Log(GetUserMarketConfig(1, "poloniex")) | ||
31 | |||
32 | config.Config["secret2"] = "key2" | ||
33 | t.Log(SetUserMarketConfig(1, "poloniex", config.Config)) | ||
34 | t.Log(SetUserMarketConfig(1, "bifinance", config.Config)) | ||
35 | } | ||
diff --git a/db/errors.go b/db/errors.go new file mode 100644 index 0000000..ed5f371 --- /dev/null +++ b/db/errors.go | |||
@@ -0,0 +1,23 @@ | |||
1 | package db | ||
2 | |||
3 | import ( | ||
4 | "strings" | ||
5 | |||
6 | "github.com/go-pg/pg" | ||
7 | ) | ||
8 | |||
9 | func PGCode(err error) string { | ||
10 | if _, ok := err.(pg.Error); !ok { | ||
11 | return "" | ||
12 | } | ||
13 | |||
14 | return err.(pg.Error).Field('C') | ||
15 | } | ||
16 | |||
17 | func IsDup(err error) bool { | ||
18 | return PGCode(err) == "23505" | ||
19 | } | ||
20 | |||
21 | func IsSQLError(err error) bool { | ||
22 | return strings.HasPrefix(err.Error(), "ERROR #") | ||
23 | } | ||
diff --git a/db/market_config.go b/db/market_config.go new file mode 100644 index 0000000..b26c092 --- /dev/null +++ b/db/market_config.go | |||
@@ -0,0 +1,45 @@ | |||
1 | package db | ||
2 | |||
3 | import "github.com/go-pg/pg" | ||
4 | |||
5 | type MarketConfig struct { | ||
6 | Id int64 | ||
7 | MarketName string `sql:",notnull"` | ||
8 | UserId int64 `sql:",notnull"` | ||
9 | Config map[string]string | ||
10 | } | ||
11 | |||
12 | func InsertMarketConfig(config *MarketConfig) error { | ||
13 | return DB.Insert(config) | ||
14 | } | ||
15 | |||
16 | func GetUserMarketConfig(userId int64, market string) (*MarketConfig, error) { | ||
17 | var config MarketConfig | ||
18 | |||
19 | err := DB.Model(&config).Where("user_id = ?", userId).Where("market_name = ?", market).First() | ||
20 | |||
21 | if err != nil && err != pg.ErrNoRows { | ||
22 | return nil, err | ||
23 | } | ||
24 | |||
25 | if err == pg.ErrNoRows { | ||
26 | return nil, nil | ||
27 | } else { | ||
28 | return &config, nil | ||
29 | } | ||
30 | } | ||
31 | |||
32 | func SetUserMarketConfig(userId int64, market string, newConfig map[string]string) (*MarketConfig, error) { | ||
33 | config := MarketConfig{ | ||
34 | UserId: userId, | ||
35 | MarketName: market, | ||
36 | Config: newConfig, | ||
37 | } | ||
38 | |||
39 | _, err := DB.Model(&config). | ||
40 | OnConflict("(user_id, market_name) DO UPDATE"). | ||
41 | Set("config = ?", newConfig). | ||
42 | Insert() | ||
43 | |||
44 | return &config, err | ||
45 | } | ||
diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..aed0ac1 --- /dev/null +++ b/db/user.go | |||
@@ -0,0 +1,72 @@ | |||
1 | package db | ||
2 | |||
3 | import ( | ||
4 | "golang.org/x/crypto/bcrypt" | ||
5 | ) | ||
6 | |||
7 | type UserStatus uint8 | ||
8 | |||
9 | const ( | ||
10 | Confirmed UserStatus = iota + 1 | ||
11 | AwaitingConfirmation | ||
12 | ) | ||
13 | |||
14 | type User struct { | ||
15 | Id int64 | ||
16 | Email string `sql:",unique,notnull"` | ||
17 | PasswordHash string `sql:",notnull"` | ||
18 | OtpSecret string | ||
19 | IsOtpSetup bool | ||
20 | Status UserStatus | ||
21 | } | ||
22 | |||
23 | func HashPassword(password string) (string, error) { | ||
24 | b, err := bcrypt.GenerateFromPassword([]byte(password), 10) | ||
25 | |||
26 | return string(b), err | ||
27 | } | ||
28 | |||
29 | func ValidatePassword(password string, hash string) error { | ||
30 | return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) | ||
31 | } | ||
32 | |||
33 | func InsertUser(user *User) error { | ||
34 | return DB.Insert(user) | ||
35 | } | ||
36 | |||
37 | func ConfirmUserByEmail(email string) error { | ||
38 | _, err := DB.Model(&User{}).Set("status=?", Confirmed).Where("email=?", email).Returning("*").Update() | ||
39 | |||
40 | return err | ||
41 | } | ||
42 | |||
43 | func GetUserById(id int64) (*User, error) { | ||
44 | user := User{Id: id} | ||
45 | |||
46 | err := DB.Select(&user) | ||
47 | |||
48 | return &user, err | ||
49 | } | ||
50 | |||
51 | func GetUserByEmail(email string) (*User, error) { | ||
52 | var users []User | ||
53 | |||
54 | err := DB.Model(&users).Where("email = ?", email).Select() | ||
55 | |||
56 | if err != nil { | ||
57 | return nil, err | ||
58 | } | ||
59 | |||
60 | if len(users) == 0 { | ||
61 | return nil, nil | ||
62 | } | ||
63 | |||
64 | return &users[0], nil | ||
65 | } | ||
66 | |||
67 | func SetOtpSecret(user *User, secret string, temporary bool) error { | ||
68 | user.OtpSecret = secret | ||
69 | user.IsOtpSetup = !temporary | ||
70 | |||
71 | return DB.Update(user) | ||
72 | } | ||