diff options
-rw-r--r-- | cmd/tfstated/main.go | 19 | ||||
-rw-r--r-- | cmd/tfstated/main_test.go | 31 | ||||
-rw-r--r-- | pkg/database/db.go | 25 |
3 files changed, 32 insertions, 43 deletions
diff --git a/cmd/tfstated/main.go b/cmd/tfstated/main.go index e6c6272..a1721e1 100644 --- a/cmd/tfstated/main.go +++ b/cmd/tfstated/main.go @@ -10,7 +10,6 @@ import ( "net/http" "os" "os/signal" - "strconv" "sync" "time" @@ -36,22 +35,6 @@ func run( ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) defer cancel() - dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY") - if dataEncryptionKey == "" { - return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set") - } - if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil { - return err - } - versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT") - if versionsHistoryLimit != "" { - n, err := strconv.Atoi(versionsHistoryLimit) - if err != nil { - return fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable: %w", err) - } - db.SetVersionsHistoryLimit(n) - } - if err := db.InitAdminAccount(); err != nil { return err } @@ -107,7 +90,7 @@ func main() { Port: "8080", } - db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate") + db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate", os.Getenv) if err != nil { fmt.Fprintf(os.Stderr, "database init error: %+v\n", err) os.Exit(1) diff --git a/cmd/tfstated/main_test.go b/cmd/tfstated/main_test.go index bd21918..dc19a3f 100644 --- a/cmd/tfstated/main_test.go +++ b/cmd/tfstated/main_test.go @@ -24,19 +24,10 @@ var adminPassword string var adminPasswordMutex sync.Mutex func TestMain(m *testing.M) { - ctx := context.Background() - ctx, cancel := context.WithCancel(ctx) config := Config{ Host: "127.0.0.1", Port: "8081", } - _ = os.Remove("./test.db") - var err error - db, err = database.NewDB(ctx, "./test.db") - if err != nil { - fmt.Fprintf(os.Stderr, "%+v\n", err) - os.Exit(1) - } getenv := func(key string) string { switch key { case "DATA_ENCRYPTION_KEY": @@ -48,6 +39,15 @@ func TestMain(m *testing.M) { } } + ctx := context.Background() + ctx, cancel := context.WithCancel(ctx) + _ = os.Remove("./test.db") + var err error + db, err = database.NewDB(ctx, "./test.db", getenv) + if err != nil { + fmt.Fprintf(os.Stderr, "%+v\n", err) + os.Exit(1) + } database.AdvertiseAdminPassword = func(password string) { adminPasswordMutex.Lock() defer adminPasswordMutex.Unlock() @@ -125,14 +125,13 @@ func waitForReady( resp, err := client.Do(req) if err != nil { fmt.Printf("Error making request: %s\n", err.Error()) - continue - } - if resp.StatusCode == http.StatusOK { - fmt.Println("Endpoint is ready!") - resp.Body.Close() - return nil + } else { + _ = resp.Body.Close() + if resp.StatusCode == http.StatusOK { + fmt.Println("Endpoint is ready!") + return nil + } } - resp.Body.Close() select { case <-ctx.Done(): diff --git a/pkg/database/db.go b/pkg/database/db.go index b7f0e89..9aeec43 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "runtime" + "strconv" "git.adyxax.org/adyxax/tfstated/pkg/scrypto" ) @@ -34,7 +35,7 @@ type DB struct { writeDB *sql.DB } -func NewDB(ctx context.Context, url string) (*DB, error) { +func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, error) { readDB, err := initDB(ctx, url) if err != nil { return nil, fmt.Errorf("failed to init read database connection: %w", err) @@ -81,6 +82,20 @@ func NewDB(ctx context.Context, url string) (*DB, error) { return nil, fmt.Errorf("failed to migrate: %w", err) } + dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY") + if dataEncryptionKey == "" { + return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set") + } + if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil { + return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err) + } + versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT") + if versionsHistoryLimit != "" { + if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil { + return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err) + } + } + return &db, nil } @@ -100,14 +115,6 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row { return db.readDB.QueryRowContext(db.ctx, query, args...) } -func (db *DB) SetDataEncryptionKey(s string) error { - return db.dataEncryptionKey.FromBase64(s) -} - -func (db *DB) SetVersionsHistoryLimit(n int) { - db.versionsHistoryLimit = n -} - func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error { tx, err := db.writeDB.Begin() if err != nil { |