From 3319e74279af910523f068a96eea169c7a60d70d Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Mon, 14 Oct 2024 23:54:49 +0200 Subject: feat(tfstated): implement states versioning --- pkg/database/locks.go | 6 +---- pkg/database/sql/000_init.sql | 9 ++++++- pkg/database/states.go | 58 ++++++++++++++++++++++++++++++------------- 3 files changed, 50 insertions(+), 23 deletions(-) (limited to 'pkg/database') diff --git a/pkg/database/locks.go b/pkg/database/locks.go index 1f66975..82b66cd 100644 --- a/pkg/database/locks.go +++ b/pkg/database/locks.go @@ -25,11 +25,7 @@ func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { if lockData, err = json.Marshal(lock); err != nil { return false, err } - _, err = tx.ExecContext(db.ctx, - `INSERT INTO states(name, lock) VALUES (:name, json(:lock))`, - sql.Named("lock", lockData), - sql.Named("name", name), - ) + _, err = tx.ExecContext(db.ctx, `INSERT INTO states(name, lock) VALUES (?, json(?))`, name, lockData) if err != nil { return false, err } diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql index 8278433..08a58cd 100644 --- a/pkg/database/sql/000_init.sql +++ b/pkg/database/sql/000_init.sql @@ -5,7 +5,14 @@ CREATE TABLE schema_version ( CREATE TABLE states ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, - data BLOB, lock TEXT ) STRICT; CREATE UNIQUE INDEX states_name on states(name); + +CREATE TABLE versions ( + id INTEGER PRIMARY KEY, + state_id INTEGER, + data BLOB, + created INTEGER DEFAULT (unixepoch()), + FOREIGN KEY(state_id) REFERENCES states(id) ON DELETE CASCADE +) STRICT; diff --git a/pkg/database/states.go b/pkg/database/states.go index 500d403..9c35212 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -2,7 +2,9 @@ package database import ( "database/sql" + "errors" "fmt" + "slices" ) // returns true in case of successful deletion @@ -20,8 +22,11 @@ func (db *DB) DeleteState(name string) (bool, error) { func (db *DB) GetState(name string) ([]byte, error) { var encryptedData []byte - err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData) + err := db.QueryRow(`SELECT versions.data FROM versions JOIN states ON states.id = versions.state_id WHERE states.name = ? ORDER BY versions.id DESC LIMIT 1;`, name).Scan(&encryptedData) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return []byte{}, nil + } return nil, err } if encryptedData == nil { @@ -31,30 +36,49 @@ func (db *DB) GetState(name string) ([]byte, error) { } // returns true in case of id mismatch -func (db *DB) SetState(name string, data []byte, id string) (bool, error) { +func (db *DB) SetState(name string, data []byte, lockID string) (bool, error) { encryptedData, err := db.dataEncryptionKey.EncryptAES256(data) if err != nil { return false, err } - if id == "" { - _, err = db.Exec( - `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`, - sql.Named("data", encryptedData), - sql.Named("name", name), - ) + tx, err := db.Begin() + if err != nil { return false, err - } else { - result, err := db.Exec(`UPDATE states SET data = ? WHERE name = ? and lock->>'ID' = ?;`, encryptedData, name, id) + } + defer func() { if err != nil { - return false, err + _ = tx.Rollback() } - n, err := result.RowsAffected() - if err != nil { + }() + var ( + stateID int64 + lockData []byte + ) + if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE name = ?;`, name).Scan(&stateID, &lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + var result sql.Result + result, err = tx.ExecContext(db.ctx, `INSERT INTO states(name) VALUES (?)`, name) + if err != nil { + return false, err + } + stateID, err = result.LastInsertId() + if err != nil { + return false, err + } + } else { return false, err } - if n != 1 { - return true, fmt.Errorf("failed to update state, lock ID does not match") - } - return false, nil } + + if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { + err = fmt.Errorf("failed to update state, lock ID does not match") + return true, err + } + _, err = tx.ExecContext(db.ctx, `INSERT INTO versions(state_id, data) VALUES (?, ?);`, stateID, encryptedData) + if err != nil { + return false, err + } + // TODO delete old states + err = tx.Commit() + return false, err } -- cgit v1.2.3