fix(tfstated): add lock handler tests

This commit is contained in:
Julien Dessaux 2024-10-06 00:13:53 +02:00
parent c4ce5e8623
commit 5ef098f47e
Signed by: adyxax
GPG key ID: F92E51B86E07177E
6 changed files with 102 additions and 24 deletions

View file

@ -1,10 +1,9 @@
package main package main
import ( import (
"database/sql"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"regexp"
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/database" "git.adyxax.org/adyxax/tfstated/pkg/database"
@ -20,6 +19,18 @@ type lockRequest struct {
Who string `json:"Who"` Who string `json:"Who"`
} }
var (
validID = regexp.MustCompile("[a-f0-9]{8}-(?:[a-f0-9]{4}-){3}[a-f0-9]{12}")
)
func (l *lockRequest) valid() []error {
err := make([]error, 0)
if !validID.MatchString(l.ID) {
err = append(err, fmt.Errorf("invalid ID"))
}
return err
}
func handleLock(db *database.DB) http.Handler { func handleLock(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" { if r.URL.Path == "/" {
@ -33,13 +44,13 @@ func handleLock(db *database.DB) http.Handler {
_ = encode(w, http.StatusBadRequest, err) _ = encode(w, http.StatusBadRequest, err)
return return
} }
if errs := lock.valid(); len(errs) > 0 {
_ = encode(w, http.StatusBadRequest,
fmt.Errorf("invalid lock: %+v", errs))
return
}
if success, err := db.SetLockOrGetExistingLock(r.URL.Path, &lock); err != nil { if success, err := db.SetLockOrGetExistingLock(r.URL.Path, &lock); err != nil {
if errors.Is(err, sql.ErrNoRows) { _ = errorResponse(w, http.StatusInternalServerError, err)
_ = encode(w, http.StatusNotFound,
fmt.Errorf("state path not found: %s", r.URL.Path))
} else {
_ = errorResponse(w, http.StatusInternalServerError, err)
}
} else if success { } else if success {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} else { } else {

45
cmd/tfstated/lock_test.go Normal file
View file

@ -0,0 +1,45 @@
package main
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func TestLock(t *testing.T) {
tests := []struct {
method string
uri *url.URL
body io.Reader
expect string
status int
msg string
}{
{"LOCK", &url.URL{Path: "/"}, nil, "", http.StatusBadRequest, "/"},
{"LOCK", &url.URL{Path: "/non_existent_lock"}, nil, "", http.StatusBadRequest, "no lock data on non existent state"},
{"LOCK", &url.URL{Path: "/non_existent_lock"}, strings.NewReader("{}"), "", http.StatusBadRequest, "invalid lock data on non existent state"},
{"LOCK", &url.URL{Path: "/test_lock"}, strings.NewReader("{\"ID\":\"00000000-0000-0000-0000-000000000000\"}"), "", http.StatusOK, "valid lock data on non existent state should create it empty"},
{"GET", &url.URL{Path: "/test_lock"}, nil, "", http.StatusOK, "/test_lock"},
{"LOCK", &url.URL{Path: "/test_lock"}, strings.NewReader("{\"ID\":\"\"}"), "", http.StatusBadRequest, "invalid lock data on already locked state"},
{"LOCK", &url.URL{Path: "/test_lock"}, strings.NewReader("{\"ID\":\"00000000-0000-0000-0000-000000000000\"}"), "", http.StatusConflict, "valid lock data on already locked state"},
{"POST", &url.URL{Path: "/test_lock", RawQuery: "ID=00000000-0000-0000-0000-000000000000"}, strings.NewReader("the_test_lock"), "", http.StatusOK, "/test_lock"},
{"GET", &url.URL{Path: "/test_lock"}, nil, "the_test_lock", http.StatusOK, "/test_lock"},
}
for _, tt := range tests {
runHTTPRequest(tt.method, tt.uri, tt.body, func(r *http.Response, err error) {
if err != nil {
t.Fatalf("failed %s with error: %+v", tt.method, err)
} else if r.StatusCode != tt.status {
t.Fatalf("%s %s should %s, got %s", tt.method, tt.msg, http.StatusText(tt.status), http.StatusText(r.StatusCode))
} else if tt.expect != "" {
if body, err := io.ReadAll(r.Body); err != nil {
t.Fatalf("failed to read body with error: %+v", err)
} else if strings.Compare(string(body), tt.expect) != 0 {
t.Fatalf("%s should have returned \"%s\", got %s", tt.method, tt.expect, string(body))
}
}
})
}
}

View file

@ -21,12 +21,12 @@ func TestPost(t *testing.T) {
{"POST", &url.URL{Path: "/test_post"}, nil, "", http.StatusBadRequest, "without a body"}, {"POST", &url.URL{Path: "/test_post"}, nil, "", http.StatusBadRequest, "without a body"},
{"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post"), "", http.StatusOK, "without lock ID in query string"}, {"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post"), "", http.StatusOK, "without lock ID in query string"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"}, {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
{"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post"}, strings.NewReader("the_test_post2"), "", http.StatusConflict, "with a lock ID on an unlocked state"}, {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=00000000-0000-0000-0000-000000000000"}, strings.NewReader("the_test_post2"), "", http.StatusConflict, "with a lock ID on an unlocked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"}, {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
{"LOCK", &url.URL{Path: "/test_post"}, strings.NewReader("{\"ID\":\"test_post_lock\"}"), "", http.StatusOK, "/test_post"}, {"LOCK", &url.URL{Path: "/test_post"}, strings.NewReader("{\"ID\":\"00000000-0000-0000-0000-000000000000\"}"), "", http.StatusOK, "/test_post"},
{"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post_invalid"}, strings.NewReader("the_test_post3"), "", http.StatusConflict, "with a wrong lock ID on a locked state"}, {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=ffffffff-ffff-ffff-ffff-ffffffffffff"}, strings.NewReader("the_test_post3"), "", http.StatusConflict, "with a wrong lock ID on a locked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"}, {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
{"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post_lock"}, strings.NewReader("the_test_post4"), "", http.StatusOK, "with a correct lock ID on a locked state"}, {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=00000000-0000-0000-0000-000000000000"}, strings.NewReader("the_test_post4"), "", http.StatusOK, "with a correct lock ID on a locked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post4", http.StatusOK, "/test_post"}, {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post4", http.StatusOK, "/test_post"},
{"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post5"), "", http.StatusOK, "without lock ID in query string on a locked state"}, {"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post5"), "", http.StatusOK, "without lock ID in query string on a locked state"},
{"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post5", http.StatusOK, "/test_post"}, {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post5", http.StatusOK, "/test_post"},

View file

@ -1,12 +1,14 @@
package database package database
import ( import (
"database/sql"
"encoding/json" "encoding/json"
"errors"
) )
// Atomically check if there is an existing lock in place on the state. Returns // Atomically check the lock status of a state and lock it if unlocked. Returns
// true if it can be set, otherwise returns false and lock is set to the value // true if the function locked the state, otherwise returns false and the lock
// of the existing lock // parameter is updated to the value of the existing lock
func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) { func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) {
tx, err := db.Begin() tx, err := db.Begin()
if err != nil { if err != nil {
@ -17,18 +19,35 @@ func (db *DB) SetLockOrGetExistingLock(name string, lock any) (bool, error) {
_ = tx.Rollback() _ = tx.Rollback()
} }
}() }()
var data []byte var lockData []byte
if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&data); err != nil { if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE name = ?;`, name).Scan(&lockData); err != nil {
if errors.Is(err, sql.ErrNoRows) {
if lockData, err = json.Marshal(lock); err != nil {
return false, err
}
_, err = tx.ExecContext(db.ctx,
`INSERT INTO states(name, lock) VALUES (:name, json(:lock))`,
sql.Named("lock", lockData),
sql.Named("name", name),
)
if err != nil {
return false, err
}
err = tx.Commit()
return true, err
} else {
return false, err
}
}
if lockData != nil {
_ = tx.Rollback()
err = json.Unmarshal(lockData, lock)
return false, err return false, err
} }
if data != nil { if lockData, err = json.Marshal(lock); err != nil {
err = json.Unmarshal(data, lock)
return false, err return false, err
} }
if data, err = json.Marshal(lock); err != nil { _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE name = ?;`, lockData, name)
return false, err
}
_, err = tx.Exec(`UPDATE states SET lock = json(?) WHERE name = ?;`, data, name)
if err != nil { if err != nil {
return false, err return false, err
} }

View file

@ -5,7 +5,7 @@ CREATE TABLE schema_version (
CREATE TABLE states ( CREATE TABLE states (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL,
data BLOB NOT NULL, data BLOB,
lock TEXT lock TEXT
) STRICT; ) STRICT;
CREATE UNIQUE INDEX states_name on states(name); CREATE UNIQUE INDEX states_name on states(name);

View file

@ -24,6 +24,9 @@ func (db *DB) GetState(name string) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if encryptedData == nil {
return []byte{}, nil
}
return db.dataEncryptionKey.DecryptAES256(encryptedData) return db.dataEncryptionKey.DecryptAES256(encryptedData)
} }