chore(tfstated): refactor database environment handling

This commit is contained in:
Julien Dessaux 2024-12-17 23:18:34 +01:00
parent 6cf2ca82de
commit c1ddc47c95
Signed by: adyxax
GPG key ID: F92E51B86E07177E
3 changed files with 32 additions and 43 deletions

View file

@ -10,7 +10,6 @@ import (
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strconv"
"sync" "sync"
"time" "time"
@ -36,22 +35,6 @@ func run(
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel() defer cancel()
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
if dataEncryptionKey == "" {
return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
}
if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil {
return err
}
versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT")
if versionsHistoryLimit != "" {
n, err := strconv.Atoi(versionsHistoryLimit)
if err != nil {
return fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable: %w", err)
}
db.SetVersionsHistoryLimit(n)
}
if err := db.InitAdminAccount(); err != nil { if err := db.InitAdminAccount(); err != nil {
return err return err
} }
@ -107,7 +90,7 @@ func main() {
Port: "8080", Port: "8080",
} }
db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate") db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate", os.Getenv)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "database init error: %+v\n", err) fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
os.Exit(1) os.Exit(1)

View file

@ -24,19 +24,10 @@ var adminPassword string
var adminPasswordMutex sync.Mutex var adminPasswordMutex sync.Mutex
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
config := Config{ config := Config{
Host: "127.0.0.1", Host: "127.0.0.1",
Port: "8081", Port: "8081",
} }
_ = os.Remove("./test.db")
var err error
db, err = database.NewDB(ctx, "./test.db")
if err != nil {
fmt.Fprintf(os.Stderr, "%+v\n", err)
os.Exit(1)
}
getenv := func(key string) string { getenv := func(key string) string {
switch key { switch key {
case "DATA_ENCRYPTION_KEY": case "DATA_ENCRYPTION_KEY":
@ -48,6 +39,15 @@ func TestMain(m *testing.M) {
} }
} }
ctx := context.Background()
ctx, cancel := context.WithCancel(ctx)
_ = os.Remove("./test.db")
var err error
db, err = database.NewDB(ctx, "./test.db", getenv)
if err != nil {
fmt.Fprintf(os.Stderr, "%+v\n", err)
os.Exit(1)
}
database.AdvertiseAdminPassword = func(password string) { database.AdvertiseAdminPassword = func(password string) {
adminPasswordMutex.Lock() adminPasswordMutex.Lock()
defer adminPasswordMutex.Unlock() defer adminPasswordMutex.Unlock()
@ -125,14 +125,13 @@ func waitForReady(
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
fmt.Printf("Error making request: %s\n", err.Error()) fmt.Printf("Error making request: %s\n", err.Error())
continue } else {
} _ = resp.Body.Close()
if resp.StatusCode == http.StatusOK { if resp.StatusCode == http.StatusOK {
fmt.Println("Endpoint is ready!") fmt.Println("Endpoint is ready!")
resp.Body.Close()
return nil return nil
} }
resp.Body.Close() }
select { select {
case <-ctx.Done(): case <-ctx.Done():

View file

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"runtime" "runtime"
"strconv"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto" "git.adyxax.org/adyxax/tfstated/pkg/scrypto"
) )
@ -34,7 +35,7 @@ type DB struct {
writeDB *sql.DB writeDB *sql.DB
} }
func NewDB(ctx context.Context, url string) (*DB, error) { func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, error) {
readDB, err := initDB(ctx, url) readDB, err := initDB(ctx, url)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to init read database connection: %w", err) return nil, fmt.Errorf("failed to init read database connection: %w", err)
@ -81,6 +82,20 @@ func NewDB(ctx context.Context, url string) (*DB, error) {
return nil, fmt.Errorf("failed to migrate: %w", err) return nil, fmt.Errorf("failed to migrate: %w", err)
} }
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
if dataEncryptionKey == "" {
return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
}
if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil {
return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err)
}
versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT")
if versionsHistoryLimit != "" {
if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil {
return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err)
}
}
return &db, nil return &db, nil
} }
@ -100,14 +115,6 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row {
return db.readDB.QueryRowContext(db.ctx, query, args...) return db.readDB.QueryRowContext(db.ctx, query, args...)
} }
func (db *DB) SetDataEncryptionKey(s string) error {
return db.dataEncryptionKey.FromBase64(s)
}
func (db *DB) SetVersionsHistoryLimit(n int) {
db.versionsHistoryLimit = n
}
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error { func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
tx, err := db.writeDB.Begin() tx, err := db.writeDB.Begin()
if err != nil { if err != nil {