-
-
Notifications
You must be signed in to change notification settings - Fork 3
/
repository.go
127 lines (99 loc) · 3.07 KB
/
repository.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package migrate
import (
"context"
"database/sql"
"errors"
"fmt"
_ "github.com/lib/pq" // postgres driver
)
type repository interface {
GetLatestMigrationNumber() (uint, error)
ApplyMigration(txFunc func(Tx) error) error
InsertMigration(m *migration) error
RemoveMigrationsAfter(number uint) error
EnsureMigrationTable() error
DropSchema(schemaName string) error
}
// repo implements repository on top of a database/sql handle opened with
// the "postgres" driver (see newRepo).
type repo struct {
	// db is the handle every query issued by repo's methods runs against.
	db *sql.DB
}
// newRepo opens a "postgres" database/sql handle for databaseURI and wraps
// it in a repo. Per the database/sql documentation, sql.Open may only
// validate its arguments without connecting, so an unreachable database may
// not be reported here but on the first query instead.
func newRepo(databaseURI string) (*repo, error) {
	conn, openErr := sql.Open("postgres", databaseURI)
	if openErr != nil {
		return nil, fmt.Errorf("failed to open database: %w", openErr)
	}

	return &repo{db: conn}, nil
}
// GetLatestMigrationNumber returns the largest value in the number column of
// the migrations table. If the table has no rows it returns 0 and a nil
// error; any other query or scan failure is wrapped and returned with 0.
func (r *repo) GetLatestMigrationNumber() (uint, error) {
	const query = "SELECT number FROM migrations ORDER BY number DESC LIMIT 1"
	row := r.db.QueryRowContext(context.TODO(), query)

	var latest uint
	switch err := row.Scan(&latest); {
	case err == nil:
		return latest, nil
	case errors.Is(err, sql.ErrNoRows):
		// No rows means no migration has been recorded yet; not an error.
		return 0, nil
	default:
		return 0, fmt.Errorf("failed to get latest migration number: %w", err)
	}
}
// ApplyMigration runs txFunc inside a single database transaction.
//
// If txFunc returns nil the transaction is committed. If it returns an
// error the transaction is rolled back and that error is returned, wrapped.
// If the rollback itself fails, the rollback error is wrapped and the
// migration error is still included in the message so it is not lost.
// A failed commit returns the commit error.
func (r *repo) ApplyMigration(txFunc func(Tx) error) error {
	dbTransaction, err := r.db.Begin()
	if err != nil {
		return fmt.Errorf("starting transaction: %w", err)
	}
	// Safety net: if txFunc panics, roll back so the transaction and its
	// connection are released while the panic keeps propagating. After the
	// explicit Rollback or Commit below, this call is a no-op returning
	// sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = dbTransaction.Rollback() }()

	if err = txFunc(Tx{dbTransaction}); err != nil {
		if rollbackErr := dbTransaction.Rollback(); rollbackErr != nil {
			return fmt.Errorf("failed to rollback after failed transaction (migration error: %v): %w", err, rollbackErr)
		}
		return fmt.Errorf("failed to apply the migration (rolled back successfully though): %w", err)
	}

	// Commit marks the transaction as finished whether or not it succeeds.
	// Attempting a Rollback here would only yield sql.ErrTxDone and mask the
	// real commit error, so the commit error is returned directly.
	if err = dbTransaction.Commit(); err != nil {
		return fmt.Errorf("failed to commit the Transaction: %w", err)
	}
	return nil
}
// InsertMigration adds a row for m to the migrations table, storing
// m.Number and m.Name. The remaining columns are not set by this statement.
// m must be non-nil.
func (r *repo) InsertMigration(m *migration) error {
	_, execErr := r.db.ExecContext(
		context.TODO(),
		"INSERT INTO migrations (number, name) VALUES ($1, $2)",
		m.Number,
		m.Name,
	)
	if execErr == nil {
		return nil
	}
	return fmt.Errorf("failed to create migration record: %w", execErr)
}
// RemoveMigrationsAfter deletes every migration record whose number is
// greater than or equal to number.
//
// NOTE(review): the bound is inclusive (SQL ">="), so the record for number
// itself is removed too, despite the "After" in the name. Confirm that
// callers rely on this before changing it.
func (r *repo) RemoveMigrationsAfter(number uint) error {
	const deleteFrom = "DELETE FROM migrations WHERE number >= $1"
	if _, err := r.db.ExecContext(context.TODO(), deleteFrom, number); err != nil {
		return fmt.Errorf("failed to delete migrations: %w", err)
	}
	return nil
}
// EnsureMigrationTable creates the migrations bookkeeping table if it does
// not exist yet. The statement uses CREATE TABLE IF NOT EXISTS, so calling
// it when the table is already present succeeds without altering it, even
// if the existing columns differ from the definition below.
func (r *repo) EnsureMigrationTable() error {
	const createTable = `
CREATE TABLE IF NOT EXISTS migrations (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP NOT NULL DEFAULT NOW(),
number INTEGER NOT NULL UNIQUE,
name VARCHAR(255) NOT NULL
)
`
	if _, err := r.db.ExecContext(context.TODO(), createTable); err != nil {
		return fmt.Errorf("failed to ensure migration table: %w", err)
	}
	return nil
}
// DropSchema drops the schema named schemaName if it exists (CASCADE, so
// every object inside it is dropped as well) and then recreates it empty.
// Both statements are sent in one ExecContext call.
//
// The name is quoted as a PostgreSQL delimited identifier. Go's %q verb,
// used previously, is not suitable: it produces Go backslash escapes
// (e.g. `"a\"b"`) that PostgreSQL does not interpret inside identifiers.
// Names containing quotes, backslashes or control characters therefore
// produced a different identifier or malformed SQL.
func (r *repo) DropSchema(schemaName string) error {
	ident, err := quotePostgresIdentifier(schemaName)
	if err != nil {
		return fmt.Errorf("failed to drop schema: %w", err)
	}
	query := fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE; CREATE SCHEMA IF NOT EXISTS %s;", ident, ident)
	if _, err := r.db.ExecContext(context.TODO(), query); err != nil {
		return fmt.Errorf("failed to drop schema: %w", err)
	}
	return nil
}

// quotePostgresIdentifier wraps name in double quotes and doubles every
// embedded double quote, which is PostgreSQL's escaping rule for delimited
// identifiers. A NUL byte cannot be part of a PostgreSQL identifier. It is
// rejected instead of truncated, because truncation could aim a destructive
// statement at a different schema.
//
// Iterating bytes is UTF-8 safe here: '"' and NUL are ASCII and never occur
// inside a multi-byte UTF-8 sequence.
func quotePostgresIdentifier(name string) (string, error) {
	quoted := make([]byte, 0, len(name)+2)
	quoted = append(quoted, '"')
	for i := 0; i < len(name); i++ {
		switch c := name[i]; c {
		case 0:
			return "", fmt.Errorf("invalid identifier %q: contains a NUL byte", name)
		case '"':
			quoted = append(quoted, '"', '"')
		default:
			quoted = append(quoted, c)
		}
	}
	quoted = append(quoted, '"')
	return string(quoted), nil
}