From 5ef098f47eaf50d7a3019be628c1f517e6df5011 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Sun, 6 Oct 2024 00:13:53 +0200 Subject: fix(tfstated): add lock handler tests --- pkg/database/locks.go | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) (limited to 'pkg/database/locks.go') diff --git a/pkg/database/locks.go b/pkg/database/locks.go index 53c6f6e..1f66975 100644 --- a/pkg/database/locks.go +++ b/pkg/database/locks.go @@ -1,12 +1,14 @@ package database import ( + "database/sql" "encoding/json" + "errors" ) -// Atomically check if there is an existing lock in place on the state. Returns -// true if it can be set, otherwise returns false and lock is set to the value -// of the existing lock +// Atomically check the lock status of a state and lock it if unlocked. Returns +// true if the function locked the state, otherwise returns false and the lock +// parameter is updated to the value of the existing lock func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { tx, err := db.Begin() if err != nil { @@ -17,18 +19,35 @@ func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { _ = tx.Rollback() } }() - var data []byte - if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&data); err != nil { - return false, err + var lockData []byte + if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + if lockData, err = json.Marshal(lock); err != nil { + return false, err + } + _, err = tx.ExecContext(db.ctx, + `INSERT INTO states(name, lock) VALUES (:name, json(:lock))`, + sql.Named("lock", lockData), + sql.Named("name", name), + ) + if err != nil { + return false, err + } + err = tx.Commit() + return true, err + } else { + return false, err + } } - if data != nil { - err = json.Unmarshal(data, lock) + if lockData != nil { + _ = tx.Rollback() + err = json.Unmarshal(lockData, lock) return false, err } - if data, err = json.Marshal(lock); err != nil { + if lockData, err = json.Marshal(lock); err != nil { return false, err } - _, err = tx.Exec(`UPDATE states SET lock = json(?) WHERE name = ?;`, data, name) + _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE name = ?;`, lockData, name) if err != nil { return false, err } -- cgit v1.2.3