summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cmd/tfstated/main.go19
-rw-r--r--cmd/tfstated/main_test.go31
-rw-r--r--pkg/database/db.go25
3 files changed, 32 insertions, 43 deletions
diff --git a/cmd/tfstated/main.go b/cmd/tfstated/main.go
index e6c6272..a1721e1 100644
--- a/cmd/tfstated/main.go
+++ b/cmd/tfstated/main.go
@@ -10,7 +10,6 @@ import (
"net/http"
"os"
"os/signal"
- "strconv"
"sync"
"time"
@@ -36,22 +35,6 @@ func run(
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
- dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
- if dataEncryptionKey == "" {
- return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
- }
- if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil {
- return err
- }
- versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT")
- if versionsHistoryLimit != "" {
- n, err := strconv.Atoi(versionsHistoryLimit)
- if err != nil {
- return fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable: %w", err)
- }
- db.SetVersionsHistoryLimit(n)
- }
-
if err := db.InitAdminAccount(); err != nil {
return err
}
@@ -107,7 +90,7 @@ func main() {
Port: "8080",
}
- db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate")
+ db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate", os.Getenv)
if err != nil {
fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
os.Exit(1)
diff --git a/cmd/tfstated/main_test.go b/cmd/tfstated/main_test.go
index bd21918..dc19a3f 100644
--- a/cmd/tfstated/main_test.go
+++ b/cmd/tfstated/main_test.go
@@ -24,19 +24,10 @@ var adminPassword string
var adminPasswordMutex sync.Mutex
func TestMain(m *testing.M) {
- ctx := context.Background()
- ctx, cancel := context.WithCancel(ctx)
config := Config{
Host: "127.0.0.1",
Port: "8081",
}
- _ = os.Remove("./test.db")
- var err error
- db, err = database.NewDB(ctx, "./test.db")
- if err != nil {
- fmt.Fprintf(os.Stderr, "%+v\n", err)
- os.Exit(1)
- }
getenv := func(key string) string {
switch key {
case "DATA_ENCRYPTION_KEY":
@@ -48,6 +39,15 @@ func TestMain(m *testing.M) {
}
}
+ ctx := context.Background()
+ ctx, cancel := context.WithCancel(ctx)
+ _ = os.Remove("./test.db")
+ var err error
+ db, err = database.NewDB(ctx, "./test.db", getenv)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%+v\n", err)
+ os.Exit(1)
+ }
database.AdvertiseAdminPassword = func(password string) {
adminPasswordMutex.Lock()
defer adminPasswordMutex.Unlock()
@@ -125,14 +125,13 @@ func waitForReady(
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error making request: %s\n", err.Error())
- continue
- }
- if resp.StatusCode == http.StatusOK {
- fmt.Println("Endpoint is ready!")
- resp.Body.Close()
- return nil
+ } else {
+ _ = resp.Body.Close()
+ if resp.StatusCode == http.StatusOK {
+ fmt.Println("Endpoint is ready!")
+ return nil
+ }
}
- resp.Body.Close()
select {
case <-ctx.Done():
diff --git a/pkg/database/db.go b/pkg/database/db.go
index b7f0e89..9aeec43 100644
--- a/pkg/database/db.go
+++ b/pkg/database/db.go
@@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"runtime"
+ "strconv"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
)
@@ -34,7 +35,7 @@ type DB struct {
writeDB *sql.DB
}
-func NewDB(ctx context.Context, url string) (*DB, error) {
+func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, error) {
readDB, err := initDB(ctx, url)
if err != nil {
return nil, fmt.Errorf("failed to init read database connection: %w", err)
@@ -81,6 +82,20 @@ func NewDB(ctx context.Context, url string) (*DB, error) {
return nil, fmt.Errorf("failed to migrate: %w", err)
}
+ dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
+ if dataEncryptionKey == "" {
+ return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
+ }
+ if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil {
+ return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err)
+ }
+ versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT")
+ if versionsHistoryLimit != "" {
+ if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil {
+ return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err)
+ }
+ }
+
return &db, nil
}
@@ -100,14 +115,6 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row {
return db.readDB.QueryRowContext(db.ctx, query, args...)
}
-func (db *DB) SetDataEncryptionKey(s string) error {
- return db.dataEncryptionKey.FromBase64(s)
-}
-
-func (db *DB) SetVersionsHistoryLimit(n int) {
- db.versionsHistoryLimit = n
-}
-
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
tx, err := db.writeDB.Begin()
if err != nil {