summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--cmd/tfstated/lock.go27
-rw-r--r--cmd/tfstated/lock_test.go45
-rw-r--r--cmd/tfstated/post_test.go8
-rw-r--r--pkg/database/locks.go39
-rw-r--r--pkg/database/sql/000_init.sql2
-rw-r--r--pkg/database/states.go3
6 files changed, 101 insertions, 23 deletions
diff --git a/cmd/tfstated/lock.go b/cmd/tfstated/lock.go
index e34c5b5..bab9c6b 100644
--- a/cmd/tfstated/lock.go
+++ b/cmd/tfstated/lock.go
@@ -1,10 +1,9 @@
package main
import (
- "database/sql"
- "errors"
"fmt"
"net/http"
+ "regexp"
"time"
"git.adyxax.org/adyxax/tfstated/pkg/database"
@@ -20,6 +19,18 @@ type lockRequest struct {
Who string `json:"Who"`
}
+var (
+ validID = regexp.MustCompile("[a-f0-9]{8}-(?:[a-f0-9]{4}-){3}[a-f0-9]{12}")
+)
+
+func (l *lockRequest) valid() []error {
+ err := make([]error, 0)
+ if !validID.MatchString(l.ID) {
+ err = append(err, fmt.Errorf("invalid ID"))
+ }
+ return err
+}
+
func handleLock(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
@@ -33,13 +44,13 @@ func handleLock(db *database.DB) http.Handler {
_ = encode(w, http.StatusBadRequest, err)
return
}
+ if errs := lock.valid(); len(errs) > 0 {
+ _ = encode(w, http.StatusBadRequest,
+ fmt.Errorf("invalid lock: %+v", errs))
+ return
+ }
if success, err := db.SetLockOrGetExistingLock(r.URL.Path, &lock); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- _ = encode(w, http.StatusNotFound,
- fmt.Errorf("state path not found: %s", r.URL.Path))
- } else {
- _ = errorResponse(w, http.StatusInternalServerError, err)
- }
+ _ = errorResponse(w, http.StatusInternalServerError, err)
} else if success {
w.WriteHeader(http.StatusOK)
} else {
diff --git a/cmd/tfstated/lock_test.go b/cmd/tfstated/lock_test.go
new file mode 100644
index 0000000..9d8261b
--- /dev/null
+++ b/cmd/tfstated/lock_test.go
@@ -0,0 +1,45 @@
+package main
+
+import (
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+)
+
+func TestLock(t *testing.T) {
+ tests := []struct {
+ method string
+ uri *url.URL
+ body io.Reader
+ expect string
+ status int
+ msg string
+ }{
+ {"LOCK", &url.URL{Path: "/"}, nil, "", http.StatusBadRequest, "/"},
+ {"LOCK", &url.URL{Path: "/non_existent_lock"}, nil, "", http.StatusBadRequest, "no lock data on non existent state"},
+ {"LOCK", &url.URL{Path: "/non_existent_lock"}, strings.NewReader("{}"), "", http.StatusBadRequest, "invalid lock data on non existent state"},
+ {"LOCK", &url.URL{Path: "/test_lock"}, strings.NewReader("{\"ID\":\"00000000-0000-0000-0000-000000000000\"}"), "", http.StatusOK, "valid lock data on non existent state should create it empty"},
+ {"GET", &url.URL{Path: "/test_lock"}, nil, "", http.StatusOK, "/test_lock"},
+ {"LOCK", &url.URL{Path: "/test_lock"}, strings.NewReader("{\"ID\":\"\"}"), "", http.StatusBadRequest, "invalid lock data on already locked state"},
+ {"LOCK", &url.URL{Path: "/test_lock"}, strings.NewReader("{\"ID\":\"00000000-0000-0000-0000-000000000000\"}"), "", http.StatusConflict, "valid lock data on already locked state"},
+ {"POST", &url.URL{Path: "/test_lock", RawQuery: "ID=00000000-0000-0000-0000-000000000000"}, strings.NewReader("the_test_lock"), "", http.StatusOK, "/test_lock"},
+ {"GET", &url.URL{Path: "/test_lock"}, nil, "the_test_lock", http.StatusOK, "/test_lock"},
+ }
+ for _, tt := range tests {
+ runHTTPRequest(tt.method, tt.uri, tt.body, func(r *http.Response, err error) {
+ if err != nil {
+ t.Fatalf("failed %s with error: %+v", tt.method, err)
+ } else if r.StatusCode != tt.status {
+ t.Fatalf("%s %s should %s, got %s", tt.method, tt.msg, http.StatusText(tt.status), http.StatusText(r.StatusCode))
+ } else if tt.expect != "" {
+ if body, err := io.ReadAll(r.Body); err != nil {
+ t.Fatalf("failed to read body with error: %+v", err)
+ } else if strings.Compare(string(body), tt.expect) != 0 {
+ t.Fatalf("%s should have returned \"%s\", got %s", tt.method, tt.expect, string(body))
+ }
+ }
+ })
+ }
+}
diff --git a/cmd/tfstated/post_test.go b/cmd/tfstated/post_test.go
index 693baa2..d346bad 100644
--- a/cmd/tfstated/post_test.go
+++ b/cmd/tfstated/post_test.go
@@ -21,12 +21,12 @@ func TestPost(t *testing.T) {
{"POST", &url.URL{Path: "/test_post"}, nil, "", http.StatusBadRequest, "without a body"},
{"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post"), "", http.StatusOK, "without lock ID in query string"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
- {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post"}, strings.NewReader("the_test_post2"), "", http.StatusConflict, "with a lock ID on an unlocked state"},
+ {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=00000000-0000-0000-0000-000000000000"}, strings.NewReader("the_test_post2"), "", http.StatusConflict, "with a lock ID on an unlocked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
- {"LOCK", &url.URL{Path: "/test_post"}, strings.NewReader("{\"ID\":\"test_post_lock\"}"), "", http.StatusOK, "/test_post"},
- {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post_invalid"}, strings.NewReader("the_test_post3"), "", http.StatusConflict, "with a wrong lock ID on a locked state"},
+ {"LOCK", &url.URL{Path: "/test_post"}, strings.NewReader("{\"ID\":\"00000000-0000-0000-0000-000000000000\"}"), "", http.StatusOK, "/test_post"},
+ {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=ffffffff-ffff-ffff-ffff-ffffffffffff"}, strings.NewReader("the_test_post3"), "", http.StatusConflict, "with a wrong lock ID on a locked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
- {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post_lock"}, strings.NewReader("the_test_post4"), "", http.StatusOK, "with a correct lock ID on a locked state"},
+ {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=00000000-0000-0000-0000-000000000000"}, strings.NewReader("the_test_post4"), "", http.StatusOK, "with a correct lock ID on a locked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post4", http.StatusOK, "/test_post"},
{"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post5"), "", http.StatusOK, "without lock ID in query string on a locked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post5", http.StatusOK, "/test_post"},
diff --git a/pkg/database/locks.go b/pkg/database/locks.go
index 53c6f6e..1f66975 100644
--- a/pkg/database/locks.go
+++ b/pkg/database/locks.go
@@ -1,12 +1,14 @@
package database
import (
+ "database/sql"
"encoding/json"
+ "errors"
)
-// Atomically check if there is an existing lock in place on the state. Returns
-// true if it can be set, otherwise returns false and lock is set to the value
-// of the existing lock
+// Atomically check the lock status of a state and lock it if unlocked. Returns
+// true if the function locked the state, otherwise returns false and the lock
+// parameter is updated to the value of the existing lock
func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) {
tx, err := db.Begin()
if err != nil {
@@ -17,18 +19,35 @@ func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) {
_ = tx.Rollback()
}
}()
- var data []byte
- if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&data); err != nil {
- return false, err
+ var lockData []byte
+ if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&lockData); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ if lockData, err = json.Marshal(lock); err != nil {
+ return false, err
+ }
+ _, err = tx.ExecContext(db.ctx,
+ `INSERT INTO states(name, lock) VALUES (:name, json(:lock))`,
+ sql.Named("lock", lockData),
+ sql.Named("name", name),
+ )
+ if err != nil {
+ return false, err
+ }
+ err = tx.Commit()
+ return true, err
+ } else {
+ return false, err
+ }
}
- if data != nil {
- err = json.Unmarshal(data, lock)
+ if lockData != nil {
+ _ = tx.Rollback()
+ err = json.Unmarshal(lockData, lock)
return false, err
}
- if data, err = json.Marshal(lock); err != nil {
+ if lockData, err = json.Marshal(lock); err != nil {
return false, err
}
- _, err = tx.Exec(`UPDATE states SET lock = json(?) WHERE name = ?;`, data, name)
+ _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE name = ?;`, lockData, name)
if err != nil {
return false, err
}
diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql
index 6ec2ef7..8278433 100644
--- a/pkg/database/sql/000_init.sql
+++ b/pkg/database/sql/000_init.sql
@@ -5,7 +5,7 @@ CREATE TABLE schema_version (
CREATE TABLE states (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
- data BLOB NOT NULL,
+ data BLOB,
lock TEXT
) STRICT;
CREATE UNIQUE INDEX states_name on states(name);
diff --git a/pkg/database/states.go b/pkg/database/states.go
index 46a5536..500d403 100644
--- a/pkg/database/states.go
+++ b/pkg/database/states.go
@@ -24,6 +24,9 @@ func (db *DB) GetState(name string) ([]byte, error) {
if err != nil {
return nil, err
}
+ if encryptedData == nil {
+ return []byte{}, nil
+ }
return db.dataEncryptionKey.DecryptAES256(encryptedData)
}