summaryrefslogtreecommitdiff
path: root/pkg/database/states.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--pkg/database/states.go87
1 files changed, 39 insertions, 48 deletions
diff --git a/pkg/database/states.go b/pkg/database/states.go
index 4f0ce58..74839da 100644
--- a/pkg/database/states.go
+++ b/pkg/database/states.go
@@ -48,52 +48,46 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) (
if err != nil {
return false, err
}
- tx, err := db.Begin()
- if err != nil {
- return false, err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
- var (
- stateID int64
- lockData []byte
- )
- if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- var result sql.Result
- result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path)
- if err != nil {
- return false, err
- }
- stateID, err = result.LastInsertId()
- if err != nil {
- return false, err
+ ret := false
+ return ret, db.WithTransaction(func(tx *sql.Tx) error {
+ var (
+ stateID int64
+ lockData []byte
+ )
+ if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ var result sql.Result
+ result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path)
+ if err != nil {
+ return err
+ }
+ stateID, err = result.LastInsertId()
+ if err != nil {
+ return err
+ }
+ } else {
+ return err
}
- } else {
- return false, err
}
- }
- if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 {
- err = fmt.Errorf("failed to update state, lock ID does not match")
- return true, err
- }
- _, err = tx.ExecContext(db.ctx,
- `INSERT INTO versions(account_id, state_id, data, lock)
+ if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 {
+ err = fmt.Errorf("failed to update state, lock ID does not match")
+ ret = true
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx,
+ `INSERT INTO versions(account_id, state_id, data, lock)
SELECT :accountID, :stateID, :data, lock
FROM states
WHERE states.id = :stateID;`,
- sql.Named("accountID", accountID),
- sql.Named("stateID", stateID),
- sql.Named("data", encryptedData))
- if err != nil {
- return false, err
- }
- _, err = tx.ExecContext(db.ctx,
- `DELETE FROM versions
+ sql.Named("accountID", accountID),
+ sql.Named("stateID", stateID),
+ sql.Named("data", encryptedData))
+ if err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx,
+ `DELETE FROM versions
WHERE state_id = (SELECT id
FROM states
WHERE path = :path)
@@ -104,12 +98,9 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) (
WHERE states.path = :path
ORDER BY versions.id DESC
LIMIT :limit));`,
- sql.Named("limit", db.versionsHistoryLimit),
- sql.Named("path", path),
- )
- if err != nil {
- return false, err
- }
- err = tx.Commit()
- return false, err
+ sql.Named("limit", db.versionsHistoryLimit),
+ sql.Named("path", path),
+ )
+ return err
+ })
}