summaryrefslogtreecommitdiff
path: root/pkg/database/migrations.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/database/migrations.go')
-rw-r--r--pkg/database/migrations.go44
1 files changed, 16 insertions, 28 deletions
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
index b460884..31cc3d3 100644
--- a/pkg/database/migrations.go
+++ b/pkg/database/migrations.go
@@ -1,6 +1,7 @@
package database
import (
+ "database/sql"
"embed"
"io/fs"
@@ -28,36 +29,23 @@ func (db *DB) migrate() error {
return err
}
- tx, err := db.Begin()
- if err != nil {
- return err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
-
- var version int
- if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
- if err.Error() == "no such table: schema_version" {
- version = 0
- } else {
- return err
+ return db.WithTransaction(func(tx *sql.Tx) error {
+ var version int
+ if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
+ if err.Error() == "no such table: schema_version" {
+ version = 0
+ } else {
+ return err
+ }
}
- }
- for version < len(statements) {
- if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
- return err
+ for version < len(statements) {
+ if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
+ return err
+ }
+ version++
}
- version++
- }
- if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
+ _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version)
return err
- }
- if err = tx.Commit(); err != nil {
- return err
- }
- return nil
+ })
}