summaryrefslogtreecommitdiff
path: root/cmd
diff options
context:
space:
mode:
Diffstat (limited to 'cmd')
-rw-r--r--cmd/tfstated/get_test.go40
-rw-r--r--cmd/tfstated/healthz_test.go17
-rw-r--r--cmd/tfstated/main.go18
-rw-r--r--cmd/tfstated/main_test.go131
-rw-r--r--cmd/tfstated/post.go2
-rw-r--r--cmd/tfstated/post_test.go49
6 files changed, 248 insertions, 9 deletions
diff --git a/cmd/tfstated/get_test.go b/cmd/tfstated/get_test.go
new file mode 100644
index 0000000..a6d9382
--- /dev/null
+++ b/cmd/tfstated/get_test.go
@@ -0,0 +1,40 @@
+package main
+
+import (
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+)
+
+func TestGet(t *testing.T) {
+ tests := []struct {
+ method string
+ uri *url.URL
+ body io.Reader
+ expect string
+ status int
+ msg string
+ }{
+ {"GET", &url.URL{Path: "/"}, nil, "", http.StatusBadRequest, "/"},
+ {"GET", &url.URL{Path: "/non_existent_get"}, nil, "", http.StatusNotFound, "non existent"},
+ {"POST", &url.URL{Path: "/test_get"}, strings.NewReader("the_test_get"), "", http.StatusOK, "/test_get"},
+ {"GET", &url.URL{Path: "/test_get"}, nil, "the_test_get", http.StatusOK, "/test_get"},
+ }
+ for _, tt := range tests {
+ runHTTPRequest(tt.method, tt.uri, tt.body, func(r *http.Response, err error) {
+ if err != nil {
+ t.Fatalf("failed %s with error: %+v", tt.method, err)
+ } else if r.StatusCode != tt.status {
+ t.Fatalf("%s %s should %s, got %s", tt.method, tt.msg, http.StatusText(tt.status), http.StatusText(r.StatusCode))
+ } else if tt.expect != "" {
+ if body, err := io.ReadAll(r.Body); err != nil {
+ t.Fatalf("failed to read body with error: %+v", err)
+ } else if strings.Compare(string(body), tt.expect) != 0 {
+ t.Fatalf("%s should have returned \"%s\", got %s", tt.method, tt.expect, string(body))
+ }
+ }
+ })
+ }
+}
diff --git a/cmd/tfstated/healthz_test.go b/cmd/tfstated/healthz_test.go
new file mode 100644
index 0000000..c4860c4
--- /dev/null
+++ b/cmd/tfstated/healthz_test.go
@@ -0,0 +1,17 @@
+package main
+
+import (
+ "net/http"
+ "net/url"
+ "testing"
+)
+
+func TestHealthz(t *testing.T) {
+ runHTTPRequest("GET", &url.URL{Path: "/healthz"}, nil, func(r *http.Response, err error) {
+ if err != nil {
+ t.Fatalf("failed healthcheck with error: %+v", err)
+ } else if r.StatusCode != http.StatusOK {
+ t.Fatalf("healthcheck should succeed, got %s", http.StatusText(r.StatusCode))
+ }
+ })
+}
diff --git a/cmd/tfstated/main.go b/cmd/tfstated/main.go
index cf93976..57c4461 100644
--- a/cmd/tfstated/main.go
+++ b/cmd/tfstated/main.go
@@ -26,10 +26,11 @@ func run(
ctx context.Context,
config *Config,
db *database.DB,
- args []string,
+ //args []string,
getenv func(string) string,
- stdin io.Reader,
- stdout, stderr io.Writer,
+ //stdin io.Reader,
+ //stdout io.Writer,
+ stderr io.Writer,
) error {
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
@@ -55,7 +56,7 @@ func run(
go func() {
log.Printf("listening on %s\n", httpServer.Addr)
if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
- fmt.Fprintf(os.Stderr, "error listening and serving: %+v\n", err)
+ fmt.Fprintf(stderr, "error listening and serving: %+v\n", err)
}
}()
var wg sync.WaitGroup
@@ -67,7 +68,7 @@ func run(
shutdownCtx, cancel := context.WithTimeout(shutdownCtx, 10*time.Second)
defer cancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
- fmt.Fprintf(os.Stderr, "error shutting down http server: %+v\n", err)
+ fmt.Fprintf(stderr, "error shutting down http server: %+v\n", err)
}
}()
wg.Wait()
@@ -104,10 +105,11 @@ func main() {
ctx,
&config,
db,
- os.Args,
+ //os.Args,
os.Getenv,
- os.Stdin,
- os.Stdout, os.Stderr,
+ //os.Stdin,
+ //os.Stdout,
+ os.Stderr,
); err != nil {
fmt.Fprintf(os.Stderr, "%+v\n", err)
os.Exit(1)
diff --git a/cmd/tfstated/main_test.go b/cmd/tfstated/main_test.go
new file mode 100644
index 0000000..7dece9d
--- /dev/null
+++ b/cmd/tfstated/main_test.go
@@ -0,0 +1,131 @@
+package main
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "net/url"
+ "os"
+ "testing"
+ "time"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/database"
+)
+
+var baseURI = &url.URL{
+ Host: "127.0.0.1:8081",
+ Path: "/",
+ Scheme: "http",
+}
+
+func TestMain(m *testing.M) {
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ config := Config{
+ Host: "127.0.0.1",
+ Port: "8081",
+ }
+ _ = os.Remove("./test.db")
+ db, err := database.NewDB(ctx, "./test.db")
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%+v\n", err)
+ os.Exit(1)
+ }
+ getenv := func(key string) string {
+ switch key {
+ case "DATA_ENCRYPTION_KEY":
+ return "hP3ZSCnY3LMgfTQjwTaGrhKwdA0yXMXIfv67OJnntqM="
+ default:
+ return ""
+ }
+ }
+
+ go run(
+ ctx,
+ &config,
+ db,
+ getenv,
+ os.Stderr,
+ )
+ err = waitForReady(ctx, 5*time.Second, "http://127.0.0.1:8081/healthz")
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%+v\n", err)
+ os.Exit(1)
+ }
+
+ ret := m.Run()
+
+ cancel()
+ err = db.Close()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%+v\n", err)
+ os.Exit(1)
+ }
+ _ = os.Remove("./test.db")
+
+ os.Exit(ret)
+}
+
+func runHTTPRequest(method string, uriRef *url.URL, body io.Reader, testFunc func(*http.Response, error)) {
+ uri := baseURI.ResolveReference(uriRef)
+ client := http.Client{}
+ req, err := http.NewRequest(method, uri.String(), body)
+ if err != nil {
+ testFunc(nil, fmt.Errorf("failed to create request: %w", err))
+ return
+ }
+ resp, err := client.Do(req)
+ if err != nil {
+ testFunc(nil, fmt.Errorf("failed to do request: %w\n", err))
+ return
+ }
+ testFunc(resp, nil)
+ _ = resp.Body.Close()
+}
+
+// waitForReady calls the specified endpoint until it gets a 200
+// response or until the context is cancelled or the timeout is
+// reached.
+func waitForReady(
+ ctx context.Context,
+ timeout time.Duration,
+ endpoint string,
+) error {
+ client := http.Client{}
+ startTime := time.Now()
+ for {
+ req, err := http.NewRequestWithContext(
+ ctx,
+ http.MethodGet,
+ endpoint,
+ nil,
+ )
+ if err != nil {
+ return fmt.Errorf("failed to create request: %w", err)
+ }
+
+ resp, err := client.Do(req)
+ if err != nil {
+ fmt.Printf("Error making request: %s\n", err.Error())
+ continue
+ }
+ if resp.StatusCode == http.StatusOK {
+ fmt.Println("Endpoint is ready!")
+ resp.Body.Close()
+ return nil
+ }
+ resp.Body.Close()
+
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ default:
+ if time.Since(startTime) >= timeout {
+ return fmt.Errorf("timeout reached while waiting for endpoint")
+ }
+ // wait a little while between checks
+ time.Sleep(250 * time.Millisecond)
+ }
+ }
+}
diff --git a/cmd/tfstated/post.go b/cmd/tfstated/post.go
index 1d570b3..cf788d4 100644
--- a/cmd/tfstated/post.go
+++ b/cmd/tfstated/post.go
@@ -20,7 +20,7 @@ func handlePost(db *database.DB) http.Handler {
id := r.URL.Query().Get("ID")
data, err := io.ReadAll(r.Body)
- if err != nil {
+ if err != nil || len(data) == 0 {
_ = errorResponse(w, http.StatusBadRequest, err)
return
}
diff --git a/cmd/tfstated/post_test.go b/cmd/tfstated/post_test.go
new file mode 100644
index 0000000..693baa2
--- /dev/null
+++ b/cmd/tfstated/post_test.go
@@ -0,0 +1,49 @@
+package main
+
+import (
+ "io"
+ "net/http"
+ "net/url"
+ "strings"
+ "testing"
+)
+
+func TestPost(t *testing.T) {
+ tests := []struct {
+ method string
+ uri *url.URL
+ body io.Reader
+ expect string
+ status int
+ msg string
+ }{
+ {"POST", &url.URL{Path: "/"}, nil, "", http.StatusBadRequest, "/"},
+ {"POST", &url.URL{Path: "/test_post"}, nil, "", http.StatusBadRequest, "without a body"},
+ {"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post"), "", http.StatusOK, "without lock ID in query string"},
+ {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
+ {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post"}, strings.NewReader("the_test_post2"), "", http.StatusConflict, "with a lock ID on an unlocked state"},
+ {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
+ {"LOCK", &url.URL{Path: "/test_post"}, strings.NewReader("{\"ID\":\"test_post_lock\"}"), "", http.StatusOK, "/test_post"},
+ {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post_invalid"}, strings.NewReader("the_test_post3"), "", http.StatusConflict, "with a wrong lock ID on a locked state"},
+ {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post", http.StatusOK, "/test_post"},
+ {"POST", &url.URL{Path: "/test_post", RawQuery: "ID=test_post_lock"}, strings.NewReader("the_test_post4"), "", http.StatusOK, "with a correct lock ID on a locked state"},
+ {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post4", http.StatusOK, "/test_post"},
+ {"POST", &url.URL{Path: "/test_post"}, strings.NewReader("the_test_post5"), "", http.StatusOK, "without lock ID in query string on a locked state"},
+ {"GET", &url.URL{Path: "/test_post"}, nil, "the_test_post5", http.StatusOK, "/test_post"},
+ }
+ for _, tt := range tests {
+ runHTTPRequest(tt.method, tt.uri, tt.body, func(r *http.Response, err error) {
+ if err != nil {
+ t.Fatalf("failed %s with error: %+v", tt.method, err)
+ } else if r.StatusCode != tt.status {
+ t.Fatalf("%s %s should %s, got %s", tt.method, tt.msg, http.StatusText(tt.status), http.StatusText(r.StatusCode))
+ } else if tt.expect != "" {
+ if body, err := io.ReadAll(r.Body); err != nil {
+ t.Fatalf("failed to read body with error: %+v", err)
+ } else if strings.Compare(string(body), tt.expect) != 0 {
+ t.Fatalf("%s should have returned \"%s\", got %s", tt.method, tt.expect, string(body))
+ }
+ }
+ })
+ }
+}