From 5ef098f47eaf50d7a3019be628c1f517e6df5011 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Sun, 6 Oct 2024 00:13:53 +0200 Subject: fix(tfstated): add lock handler tests --- pkg/database/locks.go | 39 +++++++++++++++++++++++++++++---------- pkg/database/sql/000_init.sql | 2 +- pkg/database/states.go | 3 +++ 3 files changed, 33 insertions(+), 11 deletions(-) (limited to 'pkg') diff --git a/pkg/database/locks.go b/pkg/database/locks.go index 53c6f6e..1f66975 100644 --- a/pkg/database/locks.go +++ b/pkg/database/locks.go @@ -1,12 +1,14 @@ package database import ( + "database/sql" "encoding/json" + "errors" ) -// Atomically check if there is an existing lock in place on the state. Returns -// true if it can be set, otherwise returns false and lock is set to the value -// of the existing lock +// Atomically check the lock status of a state and lock it if unlocked. Returns +// true if the function locked the state, otherwise returns false and the lock +// parameter is updated to the value of the existing lock func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { tx, err := db.Begin() if err != nil { @@ -17,18 +19,35 @@ func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { _ = tx.Rollback() } }() - var data []byte - if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&data); err != nil { - return false, err + var lockData []byte + if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + if lockData, err = json.Marshal(lock); err != nil { + return false, err + } + _, err = tx.ExecContext(db.ctx, + `INSERT INTO states(name, lock) VALUES (:name, json(:lock))`, + sql.Named("lock", lockData), + sql.Named("name", name), + ) + if err != nil { + return false, err + } + err = tx.Commit() + return true, err + } else { + return false, err + } } - if data != nil { - err = json.Unmarshal(data, lock) + if lockData != nil { + _ = tx.Rollback() + err = json.Unmarshal(lockData, lock) return false, err } - if data, err = json.Marshal(lock); err != nil { + if lockData, err = json.Marshal(lock); err != nil { return false, err } - _, err = tx.Exec(`UPDATE states SET lock = json(?) WHERE name = ?;`, data, name) + _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE name = ?;`, lockData, name) if err != nil { return false, err } diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql index 6ec2ef7..8278433 100644 --- a/pkg/database/sql/000_init.sql +++ b/pkg/database/sql/000_init.sql @@ -5,7 +5,7 @@ CREATE TABLE schema_version ( CREATE TABLE states ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, - data BLOB NOT NULL, + data BLOB, lock TEXT ) STRICT; CREATE UNIQUE INDEX states_name on states(name); diff --git a/pkg/database/states.go b/pkg/database/states.go index 46a5536..500d403 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -24,6 +24,9 @@ func (db *DB) GetState(name string) ([]byte, error) { if err != nil { return nil, err } + if encryptedData == nil { + return []byte{}, nil + } return db.dataEncryptionKey.DecryptAES256(encryptedData) } -- cgit v1.2.3