From 3319e74279af910523f068a96eea169c7a60d70d Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Mon, 14 Oct 2024 23:54:49 +0200 Subject: feat(tfstated): implement states versioning --- pkg/database/states.go | 58 +++++++++++++++++++++++++++++++++++--------------- 1 file changed, 41 insertions(+), 17 deletions(-) (limited to 'pkg/database/states.go') diff --git a/pkg/database/states.go b/pkg/database/states.go index 500d403..9c35212 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -2,7 +2,9 @@ package database import ( "database/sql" + "errors" "fmt" + "slices" ) // returns true in case of successful deletion @@ -20,8 +22,11 @@ func (db *DB) DeleteState(name string) (bool, error) { func (db *DB) GetState(name string) ([]byte, error) { var encryptedData []byte - err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData) + err := db.QueryRow(`SELECT versions.data FROM versions JOIN states ON states.id = versions.state_id WHERE states.name = ? ORDER BY versions.id DESC LIMIT 1;`, name).Scan(&encryptedData) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return []byte{}, nil + } return nil, err } if encryptedData == nil { @@ -31,30 +36,49 @@ func (db *DB) GetState(name string) ([]byte, error) { } // returns true in case of id mismatch -func (db *DB) SetState(name string, data []byte, id string) (bool, error) { +func (db *DB) SetState(name string, data []byte, lockID string) (bool, error) { encryptedData, err := db.dataEncryptionKey.EncryptAES256(data) if err != nil { return false, err } - if id == "" { - _, err = db.Exec( - `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`, - sql.Named("data", encryptedData), - sql.Named("name", name), - ) + tx, err := db.Begin() + if err != nil { return false, err - } else { - result, err := db.Exec(`UPDATE states SET data = ? WHERE name = ? and lock->>'ID' = ?;`, encryptedData, name, id) + } + defer func() { if err != nil { - return false, err + _ = tx.Rollback() } - n, err := result.RowsAffected() - if err != nil { + }() + var ( + stateID int64 + lockData []byte + ) + if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE name = ?;`, name).Scan(&stateID, &lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + var result sql.Result + result, err = tx.ExecContext(db.ctx, `INSERT INTO states(name) VALUES (?)`, name) + if err != nil { + return false, err + } + stateID, err = result.LastInsertId() + if err != nil { + return false, err + } + } else { return false, err } - if n != 1 { - return true, fmt.Errorf("failed to update state, lock ID does not match") - } - return false, nil } + + if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { + err = fmt.Errorf("failed to update state, lock ID does not match") + return true, err + } + _, err = tx.ExecContext(db.ctx, `INSERT INTO versions(state_id, data) VALUES (?, ?);`, stateID, encryptedData) + if err != nil { + return false, err + } + // TODO delete old states + err = tx.Commit() + return false, err } -- cgit v1.2.3