diff options
Diffstat (limited to 'pkg/database/migrations.go')
-rw-r--r-- | pkg/database/migrations.go | 44 |
1 files changed, 16 insertions, 28 deletions
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go index b460884..31cc3d3 100644 --- a/pkg/database/migrations.go +++ b/pkg/database/migrations.go @@ -1,6 +1,7 @@ package database import ( + "database/sql" "embed" "io/fs" @@ -28,36 +29,23 @@ func (db *DB) migrate() error { return err } - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - - var version int - if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { - if err.Error() == "no such table: schema_version" { - version = 0 - } else { - return err + return db.WithTransaction(func(tx *sql.Tx) error { + var version int + if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { + if err.Error() == "no such table: schema_version" { + version = 0 + } else { + return err + } } - } - for version < len(statements) { - if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { - return err + for version < len(statements) { + if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { + return err + } + version++ } - version++ - } - if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil { + _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version) return err - } - if err = tx.Commit(); err != nil { - return err - } - return nil + }) } |