diff options
author | Julien Dessaux | 2025-02-12 20:20:02 +0100 |
---|---|---|
committer | Julien Dessaux | 2025-02-12 20:31:25 +0100 |
commit | 5611f5b468303aee16f5cb21c8eda9a44bb5d013 (patch) | |
tree | 41e99f399b358d8f58cf584b8dd9c215d6402fd4 /golang/pkg/database/migrations.go | |
parent | [golang] updated dependencies (diff) | |
download | spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.tar.gz spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.tar.bz2 spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.zip |
[golang] rewrite database handling
Diffstat (limited to '')
-rw-r--r-- | golang/pkg/database/migrations.go | 69 |
1 files changed, 20 insertions, 49 deletions
diff --git a/golang/pkg/database/migrations.go b/golang/pkg/database/migrations.go index b167087..31cc3d3 100644 --- a/golang/pkg/database/migrations.go +++ b/golang/pkg/database/migrations.go @@ -1,7 +1,6 @@ package database import ( - "context" "database/sql" "embed" "io/fs" @@ -12,35 +11,9 @@ import ( //go:embed sql/*.sql var schemaFiles embed.FS -func DBInit(ctx context.Context, url string) (myDB *DB, err error) { - var db *sql.DB - if db, err = sql.Open("sqlite3", url); err != nil { - return nil, err - } - defer func() { - if err != nil { - _ = db.Close() - } - }() - - if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { - return nil, err - } - if _, err = db.ExecContext(ctx, "PRAGMA journal_mode = WAL"); err != nil { - return nil, err - } - - var version int - if err = db.QueryRowContext(ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { - if err.Error() == "no such table: schema_version" { - version = 0 - } else { - return nil, err - } - } - +func (db *DB) migrate() error { statements := make([]string, 0) - err = fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error { + err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error { if d.IsDir() || err != nil { return err } @@ -53,28 +26,26 @@ func DBInit(ctx context.Context, url string) (myDB *DB, err error) { return nil }) if err != nil { - return nil, err + return err } - tx, err := db.Begin() - if err != nil { - return nil, err - } - for version < len(statements) { - if _, err = tx.ExecContext(ctx, statements[version]); err != nil { - tx.Rollback() - return nil, err + return db.WithTransaction(func(tx *sql.Tx) error { + var version int + if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { + if err.Error() == "no such table: schema_version" { + version = 0 + } else { + return err + } } - version++ - } - if _, err = tx.ExecContext(ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil { - tx.Rollback() - return nil, err - } - tx.Commit() - return &DB{ctx: ctx, db: db}, nil -} -func (db *DB) Close() error { - return db.db.Close() + for version < len(statements) { + if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { + return err + } + version++ + } + _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version) + return err + }) } |