From f649f7bbbf978db0ad50b0c6d1432264cfec3f43 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Mon, 18 Nov 2024 00:10:58 +0100 Subject: chore(tfstated): implement a transaction wrapper --- pkg/database/states.go | 87 ++++++++++++++++++++++---------------------------- 1 file changed, 39 insertions(+), 48 deletions(-) (limited to 'pkg/database/states.go') diff --git a/pkg/database/states.go b/pkg/database/states.go index 4f0ce58..74839da 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -48,52 +48,46 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) ( if err != nil { return false, err } - tx, err := db.Begin() - if err != nil { - return false, err - } - defer func() { - if err != nil { - _ = tx.Rollback() - } - }() - var ( - stateID int64 - lockData []byte - ) - if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil { - if errors.Is(err, sql.ErrNoRows) { - var result sql.Result - result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path) - if err != nil { - return false, err - } - stateID, err = result.LastInsertId() - if err != nil { - return false, err + ret := false + return ret, db.WithTransaction(func(tx *sql.Tx) error { + var ( + stateID int64 + lockData []byte + ) + if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + var result sql.Result + result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path) + if err != nil { + return err + } + stateID, err = result.LastInsertId() + if err != nil { + return err + } + } else { + return err } - } else { - return false, err } - } - if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { - err = fmt.Errorf("failed to update state, lock ID does not match") - return true, err - } - _, err = tx.ExecContext(db.ctx, - `INSERT INTO versions(account_id, state_id, data, lock) + if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { + err = fmt.Errorf("failed to update state, lock ID does not match") + ret = true + return err + } + _, err = tx.ExecContext(db.ctx, + `INSERT INTO versions(account_id, state_id, data, lock) SELECT :accountID, :stateID, :data, lock FROM states WHERE states.id = :stateID;`, - sql.Named("accountID", accountID), - sql.Named("stateID", stateID), - sql.Named("data", encryptedData)) - if err != nil { - return false, err - } - _, err = tx.ExecContext(db.ctx, - `DELETE FROM versions + sql.Named("accountID", accountID), + sql.Named("stateID", stateID), + sql.Named("data", encryptedData)) + if err != nil { + return err + } + _, err = tx.ExecContext(db.ctx, + `DELETE FROM versions WHERE state_id = (SELECT id FROM states WHERE path = :path) @@ -104,12 +98,9 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) ( WHERE states.path = :path ORDER BY versions.id DESC LIMIT :limit));`, - sql.Named("limit", db.versionsHistoryLimit), - sql.Named("path", path), - ) - if err != nil { - return false, err - } - err = tx.Commit() - return false, err + sql.Named("limit", db.versionsHistoryLimit), + sql.Named("path", path), + ) + return err + }) } -- cgit v1.2.3