diff options
author | Julien Dessaux | 2025-02-12 20:20:02 +0100 |
---|---|---|
committer | Julien Dessaux | 2025-02-12 20:31:25 +0100 |
commit | 5611f5b468303aee16f5cb21c8eda9a44bb5d013 (patch) | |
tree | 41e99f399b358d8f58cf584b8dd9c215d6402fd4 /golang | |
parent | [golang] updated dependencies (diff) | |
download | spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.tar.gz spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.tar.bz2 spacetraders-5611f5b468303aee16f5cb21c8eda9a44bb5d013.zip |
[golang] rewrite database handling
Diffstat (limited to '')
-rw-r--r-- | golang/cmd/spacetraders/main.go | 5 | ||||
-rw-r--r-- | golang/pkg/database/db.go | 114 | ||||
-rw-r--r-- | golang/pkg/database/migrations.go | 69 |
3 files changed, 134 insertions, 54 deletions
diff --git a/golang/cmd/spacetraders/main.go b/golang/cmd/spacetraders/main.go index 84d34ea..e97cc32 100644 --- a/golang/cmd/spacetraders/main.go +++ b/golang/cmd/spacetraders/main.go @@ -26,7 +26,10 @@ func main() { ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) defer cancel() - db, err := database.DBInit(ctx, "./spacetraders.db") + db, err := database.NewDB( + ctx, + "./spacetraders.db?_txlock=immediate", + ) if err != nil { fmt.Fprintf(os.Stderr, "DbInit error %+v\n", err) os.Exit(1) diff --git a/golang/pkg/database/db.go b/golang/pkg/database/db.go index 3e07db3..cb15e52 100644 --- a/golang/pkg/database/db.go +++ b/golang/pkg/database/db.go @@ -3,17 +3,123 @@ package database import ( "context" "database/sql" + "fmt" + "runtime" ) +func initDB(ctx context.Context, url string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", url) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + defer func() { + if err != nil { + _ = db.Close() + } + }() + if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil { + return nil, fmt.Errorf("failed to set pragma: %w", err) + } + + return db, nil +} + type DB struct { - ctx context.Context - db *sql.DB + ctx context.Context + readDB *sql.DB + writeDB *sql.DB +} + +func NewDB(ctx context.Context, url string) (*DB, error) { + readDB, err := initDB(ctx, url) + if err != nil { + return nil, fmt.Errorf("failed to init read database connection: %w", err) + } + defer func() { + if err != nil { + _ = readDB.Close() + } + }() + readDB.SetMaxOpenConns(max(4, runtime.NumCPU())) + + writeDB, err := initDB(ctx, url) + if err != nil { + return nil, fmt.Errorf("failed to init write database connection: %w", err) + } + defer func() { + if err != nil { + _ = writeDB.Close() + } + }() + writeDB.SetMaxOpenConns(1) + + db := DB{ + ctx: ctx, + readDB: readDB, + writeDB: writeDB, + } + pragmas := []struct { + key string + value string + }{ + {"foreign_keys", "ON"}, + {"cache_size", "10000000"}, + {"journal_mode", "WAL"}, + {"synchronous", "NORMAL"}, + } + for _, pragma := range pragmas { + if _, err = db.Exec(fmt.Sprintf("PRAGMA %s = %s", pragma.key, pragma.value)); err != nil { + return nil, fmt.Errorf("failed to set pragma: %w", err) + } + } + if err = db.migrate(); err != nil { + return nil, fmt.Errorf("failed to migrate: %w", err) + } + + return &db, nil +} + +func (db *DB) Close() error { + if err := db.readDB.Close(); err != nil { + _ = db.writeDB.Close() + return fmt.Errorf("failed to close read database connection: %w", err) + } + if err := db.writeDB.Close(); err != nil { + return fmt.Errorf("failed to close write database connection: %w", err) + } + return nil } func (db *DB) Exec(query string, args ...any) (sql.Result, error) { - return db.db.ExecContext(db.ctx, query, args...) + return db.writeDB.ExecContext(db.ctx, query, args...) +} + +func (db *DB) Query(query string, args ...any) (*sql.Rows, error) { + return db.readDB.QueryContext(db.ctx, query, args...) } func (db *DB) QueryRow(query string, args ...any) *sql.Row { - return db.db.QueryRowContext(db.ctx, query, args...) + return db.readDB.QueryRowContext(db.ctx, query, args...) +} + +func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error { + tx, err := db.writeDB.Begin() + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if err != nil { + if err2 := tx.Rollback(); err2 != nil { + panic(fmt.Sprintf("failed to rollback transaction: %+v. Reason for rollback: %+v", err2, err)) + } + } + }() + if err = f(tx); err != nil { + return fmt.Errorf("failed to execute function inside transaction: %w", err) + } else { + if err = tx.Commit(); err != nil { + err = fmt.Errorf("failed to commit transaction: %w", err) + } + } + return err } diff --git a/golang/pkg/database/migrations.go b/golang/pkg/database/migrations.go index b167087..31cc3d3 100644 --- a/golang/pkg/database/migrations.go +++ b/golang/pkg/database/migrations.go @@ -1,7 +1,6 @@ package database import ( - "context" "database/sql" "embed" "io/fs" @@ -12,35 +11,9 @@ import ( //go:embed sql/*.sql var schemaFiles embed.FS -func DBInit(ctx context.Context, url string) (myDB *DB, err error) { - var db *sql.DB - if db, err = sql.Open("sqlite3", url); err != nil { - return nil, err - } - defer func() { - if err != nil { - _ = db.Close() - } - }() - - if _, err = db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil { - return nil, err - } - if _, err = db.ExecContext(ctx, "PRAGMA journal_mode = WAL"); err != nil { - return nil, err - } - - var version int - if err = db.QueryRowContext(ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { - if err.Error() == "no such table: schema_version" { - version = 0 - } else { - return nil, err - } - } - +func (db *DB) migrate() error { statements := make([]string, 0) - err = fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error { + err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error { if d.IsDir() || err != nil { return err } @@ -53,28 +26,26 @@ func DBInit(ctx context.Context, url string) (myDB *DB, err error) { return nil }) if err != nil { - return nil, err + return err } - tx, err := db.Begin() - if err != nil { - return nil, err - } - for version < len(statements) { - if _, err = tx.ExecContext(ctx, statements[version]); err != nil { - tx.Rollback() - return nil, err + return db.WithTransaction(func(tx *sql.Tx) error { + var version int + if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { + if err.Error() == "no such table: schema_version" { + version = 0 + } else { + return err + } } - version++ - } - if _, err = tx.ExecContext(ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil { - tx.Rollback() - return nil, err - } - tx.Commit() - return &DB{ctx: ctx, db: db}, nil -} -func (db *DB) Close() error { - return db.db.Close() + for version < len(statements) { + if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { + return err + } + version++ + } + _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version) + return err + }) } |