summaryrefslogtreecommitdiff
path: root/golang/pkg/database/migrations.go
diff options
context:
space:
mode:
authorJulien Dessaux2025-02-12 20:20:02 +0100
committerJulien Dessaux2025-02-12 20:31:25 +0100
commit5611f5b468303aee16f5cb21c8eda9a44bb5d013 (patch)
tree41e99f399b358d8f58cf584b8dd9c215d6402fd4 /golang/pkg/database/migrations.go
parent[golang] updated dependencies (diff)
downloadspacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.tar.gz
spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.tar.bz2
spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.zip
[golang] rewrite database handling
Diffstat (limited to '')
-rw-r--r--golang/pkg/database/migrations.go69
1 files changed, 20 insertions, 49 deletions
diff --git a/golang/pkg/database/migrations.go b/golang/pkg/database/migrations.go
index b167087..31cc3d3 100644
--- a/golang/pkg/database/migrations.go
+++ b/golang/pkg/database/migrations.go
@@ -1,7 +1,6 @@
package database
import (
- "context"
"database/sql"
"embed"
"io/fs"
@@ -12,35 +11,9 @@ import (
//go:embed sql/*.sql
var schemaFiles embed.FS
-func DBInit(ctx context.Context, url string) (myDB *DB, err error) {
- var db *sql.DB
- if db, err = sql.Open("sqlite3", url); err != nil {
- return nil, err
- }
- defer func() {
- if err != nil {
- _ = db.Close()
- }
- }()
-
- if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
- return nil, err
- }
- if _, err = db.ExecContext(ctx, "PRAGMA journal_mode = WAL"); err != nil {
- return nil, err
- }
-
- var version int
- if err = db.QueryRowContext(ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
- if err.Error() == "no such table: schema_version" {
- version = 0
- } else {
- return nil, err
- }
- }
-
+func (db *DB) migrate() error {
statements := make([]string, 0)
- err = fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error {
+ err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error {
if d.IsDir() || err != nil {
return err
}
@@ -53,28 +26,26 @@ func DBInit(ctx context.Context, url string) (myDB *DB, err error) {
return nil
})
if err != nil {
- return nil, err
+ return err
}
- tx, err := db.Begin()
- if err != nil {
- return nil, err
- }
- for version < len(statements) {
- if _, err = tx.ExecContext(ctx, statements[version]); err != nil {
- tx.Rollback()
- return nil, err
+ return db.WithTransaction(func(tx *sql.Tx) error {
+ var version int
+ if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
+ if err.Error() == "no such table: schema_version" {
+ version = 0
+ } else {
+ return err
+ }
}
- version++
- }
- if _, err = tx.ExecContext(ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
- tx.Rollback()
- return nil, err
- }
- tx.Commit()
- return &DB{ctx: ctx, db: db}, nil
-}
-func (db *DB) Close() error {
- return db.db.Close()
+ for version < len(statements) {
+ if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
+ return err
+ }
+ version++
+ }
+ _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version)
+ return err
+ })
}