From f649f7bbbf978db0ad50b0c6d1432264cfec3f43 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Mon, 18 Nov 2024 00:10:58 +0100 Subject: chore(tfstated): implement a transaction wrapper --- pkg/database/accounts.go | 55 ++++++++++++----------------- pkg/database/db.go | 22 +++++++++--- pkg/database/locks.go | 61 +++++++++++++------------------- pkg/database/migrations.go | 44 +++++++++-------------- pkg/database/states.go | 87 +++++++++++++++++++++------------------------- 5 files changed, 119 insertions(+), 150 deletions(-) (limited to 'pkg') diff --git a/pkg/database/accounts.go b/pkg/database/accounts.go index f506155..e1c7109 100644 --- a/pkg/database/accounts.go +++ b/pkg/database/accounts.go @@ -50,42 +50,31 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) { } func (db *DB) InitAdminAccount() error { - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - var hasAdminAccount bool - if err = tx.QueryRowContext(db.ctx, `SELECT EXISTS (SELECT 1 FROM accounts WHERE is_admin);`).Scan(&hasAdminAccount); err != nil { - return fmt.Errorf("failed to select if there is an admin account in the database: %w", err) - } - if hasAdminAccount { - tx.Rollback() - } else { - var password uuid.UUID - if err = password.Generate(uuid.V4); err != nil { - return fmt.Errorf("failed to generate initial admin password: %w", err) + return db.WithTransaction(func(tx *sql.Tx) error { + var hasAdminAccount bool + if err := tx.QueryRowContext(db.ctx, `SELECT EXISTS (SELECT 1 FROM accounts WHERE is_admin);`).Scan(&hasAdminAccount); err != nil { + return fmt.Errorf("failed to select if there is an admin account in the database: %w", err) } - salt := helpers.GenerateSalt() - hash := helpers.HashPassword(password.String(), salt) - if _, err = tx.ExecContext(db.ctx, - `INSERT INTO accounts(username, salt, password_hash, is_admin) + if !hasAdminAccount { + var password uuid.UUID + if err := password.Generate(uuid.V4); err != nil { + return fmt.Errorf("failed to generate initial admin password: %w", err) + } + salt := helpers.GenerateSalt() + hash := helpers.HashPassword(password.String(), salt) + if _, err := tx.ExecContext(db.ctx, + `INSERT INTO accounts(username, salt, password_hash, is_admin) VALUES ("admin", :salt, :hash, TRUE) ON CONFLICT DO UPDATE SET password_hash = :hash WHERE username = "admin";`, - sql.Named("salt", salt), - sql.Named("hash", hash), - ); err != nil { - return fmt.Errorf("failed to set initial admin password: %w", err) - } - err = tx.Commit() - if err == nil { - AdvertiseAdminPassword(password.String()) + sql.Named("salt", salt), + sql.Named("hash", hash), + ); err == nil { + AdvertiseAdminPassword(password.String()) + } else { + return fmt.Errorf("failed to set initial admin password: %w", err) + } } - } - return err + return nil + }) } diff --git a/pkg/database/db.go b/pkg/database/db.go index 38265ba..5501a82 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -3,6 +3,7 @@ package database import ( "context" "database/sql" + "fmt" "runtime" "git.adyxax.org/adyxax/tfstated/pkg/scrypto" @@ -81,10 +82,6 @@ func NewDB(ctx context.Context, url string) (*DB, error) { return &db, nil } -func (db *DB) Begin() (*sql.Tx, error) { - return db.writeDB.Begin() -} - func (db *DB) Close() error { if err := db.readDB.Close(); err != nil { _ = db.writeDB.Close() @@ -107,3 +104,20 @@ func (db *DB) SetDataEncryptionKey(s string) error { func (db *DB) SetVersionsHistoryLimit(n int) { db.versionsHistoryLimit = n } + +func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error { + tx, err := db.writeDB.Begin() + if err != nil { + return err + } + err = f(tx) + if err == nil { + err = tx.Commit() + } + if err != nil { + if err2 := tx.Rollback(); err2 != nil { + panic(fmt.Sprintf("failed to rollback transaction: %+v. Reason for rollback: %+v", err2, err)) + } + } + return err +} diff --git a/pkg/database/locks.go b/pkg/database/locks.go index 6951337..e78fd0c 100644 --- a/pkg/database/locks.go +++ b/pkg/database/locks.go @@ -10,45 +10,32 @@ import ( // true if the function locked the state, otherwise returns false and the lock // parameter is updated to the value of the existing lock func (db *DB) SetLockOrGetExistingLock(path string, lock any) (bool, error) { - tx, err := db.Begin() - if err != nil { - return false, err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - var lockData []byte - if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil { - if errors.Is(err, sql.ErrNoRows) { - if lockData, err = json.Marshal(lock); err != nil { - return false, err + ret := false + return ret, db.WithTransaction(func(tx *sql.Tx) error { + var lockData []byte + if err := tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + if lockData, err = json.Marshal(lock); err != nil { + return err + } + _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData) + ret = true + return err + } else { + return err } - _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData) - if err != nil { - return false, err - } - err = tx.Commit() - return true, err - } else { - return false, err } - } - if lockData != nil { - _ = tx.Rollback() - err = json.Unmarshal(lockData, lock) - return false, err - } - if lockData, err = json.Marshal(lock); err != nil { - return false, err - } - _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path) - if err != nil { - return false, err - } - err = tx.Commit() - return true, err + if lockData != nil { + return json.Unmarshal(lockData, lock) + } + var err error + if lockData, err = json.Marshal(lock); err != nil { + return err + } + _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path) + ret = true + return err + }) } func (db *DB) Unlock(path, lock any) (bool, error) { diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go index b460884..31cc3d3 100644 --- a/pkg/database/migrations.go +++ b/pkg/database/migrations.go @@ -1,6 +1,7 @@ package database import ( + "database/sql" "embed" "io/fs" @@ -28,36 +29,23 @@ func (db *DB) migrate() error { return err } - tx, err := db.Begin() - if err != nil { - return err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - - var version int - if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { - if err.Error() == "no such table: schema_version" { - version = 0 - } else { - return err + return db.WithTransaction(func(tx *sql.Tx) error { + var version int + if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { + if err.Error() == "no such table: schema_version" { + version = 0 + } else { + return err + } } - } - for version < len(statements) { - if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { - return err + for version < len(statements) { + if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { + return err + } + version++ } - version++ - } - if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil { + _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version) return err - } - if err = tx.Commit(); err != nil { - return err - } - return nil + }) } diff --git a/pkg/database/states.go b/pkg/database/states.go index 4f0ce58..74839da 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -48,52 +48,46 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) ( if err != nil { return false, err } - tx, err := db.Begin() - if err != nil { - return false, err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - var ( - stateID int64 - lockData []byte - ) - if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil { - if errors.Is(err, sql.ErrNoRows) { - var result sql.Result - result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path) - if err != nil { - return false, err - } - stateID, err = result.LastInsertId() - if err != nil { - return false, err + ret := false + return ret, db.WithTransaction(func(tx *sql.Tx) error { + var ( + stateID int64 + lockData []byte + ) + if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + var result sql.Result + result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path) + if err != nil { + return err + } + stateID, err = result.LastInsertId() + if err != nil { + return err + } + } else { + return err } - } else { - return false, err } - } - if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { - err = fmt.Errorf("failed to update state, lock ID does not match") - return true, err - } - _, err = tx.ExecContext(db.ctx, - `INSERT INTO versions(account_id, state_id, data, lock) + if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { + err = fmt.Errorf("failed to update state, lock ID does not match") + ret = true + return err + } + _, err = tx.ExecContext(db.ctx, + `INSERT INTO versions(account_id, state_id, data, lock) SELECT :accountID, :stateID, :data, lock FROM states WHERE states.id = :stateID;`, - sql.Named("accountID", accountID), - sql.Named("stateID", stateID), - sql.Named("data", encryptedData)) - if err != nil { - return false, err - } - _, err = tx.ExecContext(db.ctx, - `DELETE FROM versions + sql.Named("accountID", accountID), + sql.Named("stateID", stateID), + sql.Named("data", encryptedData)) + if err != nil { + return err + } + _, err = tx.ExecContext(db.ctx, + `DELETE FROM versions WHERE state_id = (SELECT id FROM states WHERE path = :path) @@ -104,12 +98,9 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) ( WHERE states.path = :path ORDER BY versions.id DESC LIMIT :limit));`, - sql.Named("limit", db.versionsHistoryLimit), - sql.Named("path", path), - ) - if err != nil { - return false, err - } - err = tx.Commit() - return false, err + sql.Named("limit", db.versionsHistoryLimit), + sql.Named("path", path), + ) + return err + }) } -- cgit v1.2.3