diff options
-rw-r--r-- | cmd/tfstated/get.go | 9 | ||||
-rw-r--r-- | cmd/tfstated/get_test.go | 2 | ||||
-rw-r--r-- | pkg/database/locks.go | 6 | ||||
-rw-r--r-- | pkg/database/sql/000_init.sql | 9 | ||||
-rw-r--r-- | pkg/database/states.go | 58 |
5 files changed, 52 insertions, 32 deletions
diff --git a/cmd/tfstated/get.go b/cmd/tfstated/get.go index 7d54970..6bc4466 100644 --- a/cmd/tfstated/get.go +++ b/cmd/tfstated/get.go @@ -1,8 +1,6 @@ package main import ( - "database/sql" - "errors" "fmt" "net/http" @@ -20,12 +18,7 @@ func handleGet(db *database.DB) http.Handler { } if data, err := db.GetState(r.URL.Path); err != nil { - if errors.Is(err, sql.ErrNoRows) { - _ = errorResponse(w, http.StatusNotFound, - fmt.Errorf("state path not found: %s", r.URL.Path)) - } else { - _ = errorResponse(w, http.StatusInternalServerError, err) - } + _ = errorResponse(w, http.StatusInternalServerError, err) } else { w.WriteHeader(http.StatusOK) _, _ = w.Write(data) diff --git a/cmd/tfstated/get_test.go b/cmd/tfstated/get_test.go index a6d9382..aa56026 100644 --- a/cmd/tfstated/get_test.go +++ b/cmd/tfstated/get_test.go @@ -18,7 +18,7 @@ func TestGet(t *testing.T) { msg string }{ {"GET", &url.URL{Path: "/"}, nil, "", http.StatusBadRequest, "/"}, - {"GET", &url.URL{Path: "/non_existent_get"}, nil, "", http.StatusNotFound, "non existent"}, + {"GET", &url.URL{Path: "/non_existent_get"}, strings.NewReader(""), "", http.StatusOK, "non existent"}, {"POST", &url.URL{Path: "/test_get"}, strings.NewReader("the_test_get"), "", http.StatusOK, "/test_get"}, {"GET", &url.URL{Path: "/test_get"}, nil, "the_test_get", http.StatusOK, "/test_get"}, } diff --git a/pkg/database/locks.go b/pkg/database/locks.go index 1f66975..82b66cd 100644 --- a/pkg/database/locks.go +++ b/pkg/database/locks.go @@ -25,11 +25,7 @@ func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { if lockData, err = json.Marshal(lock); err != nil { return false, err } - _, err = tx.ExecContext(db.ctx, - `INSERT INTO states(name, lock) VALUES (:name, json(:lock))`, - sql.Named("lock", lockData), - sql.Named("name", name), - ) + _, err = tx.ExecContext(db.ctx, `INSERT INTO states(name, lock) VALUES (?, json(?))`, name, lockData) if err != nil { return false, err } diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql index 8278433..08a58cd 100644 --- a/pkg/database/sql/000_init.sql +++ b/pkg/database/sql/000_init.sql @@ -5,7 +5,14 @@ CREATE TABLE schema_version ( CREATE TABLE states ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, - data BLOB, lock TEXT ) STRICT; CREATE UNIQUE INDEX states_name on states(name); + +CREATE TABLE versions ( + id INTEGER PRIMARY KEY, + state_id INTEGER, + data BLOB, + created INTEGER DEFAULT (unixepoch()), + FOREIGN KEY(state_id) REFERENCES states(id) ON DELETE CASCADE +) STRICT; diff --git a/pkg/database/states.go b/pkg/database/states.go index 500d403..9c35212 100644 --- a/pkg/database/states.go +++ b/pkg/database/states.go @@ -2,7 +2,9 @@ package database import ( "database/sql" + "errors" "fmt" + "slices" ) // returns true in case of successful deletion @@ -20,8 +22,11 @@ func (db *DB) DeleteState(name string) (bool, error) { func (db *DB) GetState(name string) ([]byte, error) { var encryptedData []byte - err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData) + err := db.QueryRow(`SELECT versions.data FROM versions JOIN states ON states.id = versions.state_id WHERE states.name = ? ORDER BY versions.id DESC LIMIT 1;`, name).Scan(&encryptedData) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return []byte{}, nil + } return nil, err } if encryptedData == nil { @@ -31,30 +36,49 @@ func (db *DB) GetState(name string) ([]byte, error) { } // returns true in case of id mismatch -func (db *DB) SetState(name string, data []byte, id string) (bool, error) { +func (db *DB) SetState(name string, data []byte, lockID string) (bool, error) { encryptedData, err := db.dataEncryptionKey.EncryptAES256(data) if err != nil { return false, err } - if id == "" { - _, err = db.Exec( - `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`, - sql.Named("data", encryptedData), - sql.Named("name", name), - ) + tx, err := db.Begin() + if err != nil { return false, err - } else { - result, err := db.Exec(`UPDATE states SET data = ? WHERE name = ? and lock->>'ID' = ?;`, encryptedData, name, id) + } + defer func() { if err != nil { - return false, err + _ = tx.Rollback() } - n, err := result.RowsAffected() - if err != nil { + }() + var ( + stateID int64 + lockData []byte + ) + if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE name = ?;`, name).Scan(&stateID, &lockData); err != nil { + if errors.Is(err, sql.ErrNoRows) { + var result sql.Result + result, err = tx.ExecContext(db.ctx, `INSERT INTO states(name) VALUES (?)`, name) + if err != nil { + return false, err + } + stateID, err = result.LastInsertId() + if err != nil { + return false, err + } + } else { return false, err } - if n != 1 { - return true, fmt.Errorf("failed to update state, lock ID does not match") - } - return false, nil } + + if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { + err = fmt.Errorf("failed to update state, lock ID does not match") + return true, err + } + _, err = tx.ExecContext(db.ctx, `INSERT INTO versions(state_id, data) VALUES (?, ?);`, stateID, encryptedData) + if err != nil { + return false, err + } + // TODO delete old states + err = tx.Commit() + return false, err } |