summaryrefslogtreecommitdiff
path: root/pkg/database/locks.go
diff options
context:
space:
mode:
authorJulien Dessaux2024-11-18 00:10:58 +0100
committerJulien Dessaux2024-12-17 23:19:21 +0100
commitf649f7bbbf978db0ad50b0c6d1432264cfec3f43 (patch)
tree2e68e6cc116b387865ea2ceee67d1df1c5f50a7b /pkg/database/locks.go
parentchore(tfstated): refactor helpers to their own package (diff)
downloadtfstated-f649f7bbbf978db0ad50b0c6d1432264cfec3f43.tar.gz
tfstated-f649f7bbbf978db0ad50b0c6d1432264cfec3f43.tar.bz2
tfstated-f649f7bbbf978db0ad50b0c6d1432264cfec3f43.zip
chore(tfstated): implement a transaction wrapper
Diffstat (limited to 'pkg/database/locks.go')
-rw-r--r--pkg/database/locks.go61
1 files changed, 24 insertions, 37 deletions
diff --git a/pkg/database/locks.go b/pkg/database/locks.go
index 6951337..e78fd0c 100644
--- a/pkg/database/locks.go
+++ b/pkg/database/locks.go
@@ -10,45 +10,32 @@ import (
// true if the function locked the state, otherwise returns false and the lock
// parameter is updated to the value of the existing lock
func (db *DB) SetLockOrGetExistingLock(path string, lock any) (bool, error) {
- tx, err := db.Begin()
- if err != nil {
- return false, err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
- var lockData []byte
- if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- if lockData, err = json.Marshal(lock); err != nil {
- return false, err
+ ret := false
+ return ret, db.WithTransaction(func(tx *sql.Tx) error {
+ var lockData []byte
+ if err := tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ if lockData, err = json.Marshal(lock); err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData)
+ ret = true
+ return err
+ } else {
+ return err
}
- _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData)
- if err != nil {
- return false, err
- }
- err = tx.Commit()
- return true, err
- } else {
- return false, err
}
- }
- if lockData != nil {
- _ = tx.Rollback()
- err = json.Unmarshal(lockData, lock)
- return false, err
- }
- if lockData, err = json.Marshal(lock); err != nil {
- return false, err
- }
- _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path)
- if err != nil {
- return false, err
- }
- err = tx.Commit()
- return true, err
+ if lockData != nil {
+ return json.Unmarshal(lockData, lock)
+ }
+ var err error
+ if lockData, err = json.Marshal(lock); err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path)
+ ret = true
+ return err
+ })
}
func (db *DB) Unlock(path, lock any) (bool, error) {