diff options
Diffstat (limited to 'db/db.go')
-rw-r--r-- | db/db.go | 86 |
1 files changed, 50 insertions, 36 deletions
@@ -1,9 +1,7 @@ | |||
1 | package db | 1 | package db |
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "fmt" | 4 | migrate "github.com/go-pg/migrations" |
5 | "strings" | ||
6 | |||
7 | "github.com/go-pg/pg" | 5 | "github.com/go-pg/pg" |
8 | "github.com/go-pg/pg/orm" | 6 | "github.com/go-pg/pg/orm" |
9 | "github.com/jloup/utils" | 7 | "github.com/jloup/utils" |
@@ -25,15 +23,60 @@ func Init(config DBConfig) { | |||
25 | 23 | ||
26 | DB = connect(config) | 24 | DB = connect(config) |
27 | 25 | ||
28 | err = createSchema(DB) | 26 | err = migratedb() |
29 | if err != nil { | 27 | if err != nil { |
30 | log.Errorf("cannot create schemas %v\n", err) | 28 | log.Fatalf("cannot migratedb '%v'\n", err) |
31 | } | 29 | } |
30 | } | ||
32 | 31 | ||
33 | err = createIndexes(DB) | 32 | func migratedb() error { |
33 | /* Remove after first MEP */ | ||
34 | version, err := migrate.Version(DB) | ||
34 | if err != nil { | 35 | if err != nil { |
35 | log.Errorf("cannot create indexes %v\n", err) | 36 | return err |
37 | } | ||
38 | |||
39 | if version == 0 { | ||
40 | return migrate.SetVersion(DB, 1) | ||
36 | } | 41 | } |
42 | /***/ | ||
43 | |||
44 | mig := make([]migrate.Migration, 0) | ||
45 | |||
46 | for _, migration := range migrations { | ||
47 | mig = append(mig, migrate.Migration{ | ||
48 | Version: migration.Version, | ||
49 | Up: func(db orm.DB) error { | ||
50 | for _, query := range migration.Up { | ||
51 | _, err := db.Exec(query) | ||
52 | if err != nil { | ||
53 | return err | ||
54 | } | ||
55 | } | ||
56 | |||
57 | return nil | ||
58 | }, | ||
59 | Down: func(db orm.DB) error { | ||
60 | for _, query := range migration.Down { | ||
61 | _, err := db.Exec(query) | ||
62 | if err != nil { | ||
63 | return err | ||
64 | } | ||
65 | } | ||
66 | |||
67 | return nil | ||
68 | }, | ||
69 | }) | ||
70 | } | ||
71 | |||
72 | oldVersion, newVersion, err := migrate.RunMigrations(DB, mig, "up") | ||
73 | |||
74 | if oldVersion != newVersion { | ||
75 | log.Infof("Migrate DB: %v -> %v", oldVersion, newVersion) | ||
76 | } else { | ||
77 | log.Infof("DB up-to-date: version '%v'", newVersion) | ||
78 | } | ||
79 | return err | ||
37 | } | 80 | } |
38 | 81 | ||
39 | func connect(config DBConfig) *pg.DB { | 82 | func connect(config DBConfig) *pg.DB { |
@@ -44,32 +87,3 @@ func connect(config DBConfig) *pg.DB { | |||
44 | Addr: config.Address, | 87 | Addr: config.Address, |
45 | }) | 88 | }) |
46 | } | 89 | } |
47 | |||
48 | func createSchema(db *pg.DB) error { | ||
49 | for _, model := range []interface{}{&User{}, &MarketConfig{}} { | ||
50 | err := db.CreateTable(model, &orm.CreateTableOptions{IfNotExists: true}) | ||
51 | if err != nil { | ||
52 | return err | ||
53 | } | ||
54 | } | ||
55 | return nil | ||
56 | } | ||
57 | |||
58 | func createIndexes(db *pg.DB) error { | ||
59 | indexes := []struct { | ||
60 | TableName string | ||
61 | Name string | ||
62 | Columns []string | ||
63 | }{ | ||
64 | {"market_configs", "market_name_user_id_idx", []string{"user_id", "market_name"}}, | ||
65 | } | ||
66 | |||
67 | for _, index := range indexes { | ||
68 | _, err := db.Exec(fmt.Sprintf("CREATE UNIQUE INDEX IF NOT EXISTS %s ON %s (%s)", index.Name, index.TableName, strings.Join(index.Columns, ","))) | ||
69 | if err != nil { | ||
70 | return err | ||
71 | } | ||
72 | } | ||
73 | |||
74 | return nil | ||
75 | } | ||