chore(tfstated): refactor database environment handling
This commit is contained in:
parent
6cf2ca82de
commit
c1ddc47c95
3 changed files with 32 additions and 43 deletions
|
@ -10,7 +10,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -36,22 +35,6 @@ func run(
|
||||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
|
|
||||||
if dataEncryptionKey == "" {
|
|
||||||
return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
|
|
||||||
}
|
|
||||||
if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT")
|
|
||||||
if versionsHistoryLimit != "" {
|
|
||||||
n, err := strconv.Atoi(versionsHistoryLimit)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable: %w", err)
|
|
||||||
}
|
|
||||||
db.SetVersionsHistoryLimit(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := db.InitAdminAccount(); err != nil {
|
if err := db.InitAdminAccount(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -107,7 +90,7 @@ func main() {
|
||||||
Port: "8080",
|
Port: "8080",
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate")
|
db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate", os.Getenv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
|
fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
|
|
@ -24,19 +24,10 @@ var adminPassword string
|
||||||
var adminPasswordMutex sync.Mutex
|
var adminPasswordMutex sync.Mutex
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
ctx := context.Background()
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
config := Config{
|
config := Config{
|
||||||
Host: "127.0.0.1",
|
Host: "127.0.0.1",
|
||||||
Port: "8081",
|
Port: "8081",
|
||||||
}
|
}
|
||||||
_ = os.Remove("./test.db")
|
|
||||||
var err error
|
|
||||||
db, err = database.NewDB(ctx, "./test.db")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "%+v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
getenv := func(key string) string {
|
getenv := func(key string) string {
|
||||||
switch key {
|
switch key {
|
||||||
case "DATA_ENCRYPTION_KEY":
|
case "DATA_ENCRYPTION_KEY":
|
||||||
|
@ -48,6 +39,15 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
_ = os.Remove("./test.db")
|
||||||
|
var err error
|
||||||
|
db, err = database.NewDB(ctx, "./test.db", getenv)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%+v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
database.AdvertiseAdminPassword = func(password string) {
|
database.AdvertiseAdminPassword = func(password string) {
|
||||||
adminPasswordMutex.Lock()
|
adminPasswordMutex.Lock()
|
||||||
defer adminPasswordMutex.Unlock()
|
defer adminPasswordMutex.Unlock()
|
||||||
|
@ -125,14 +125,13 @@ func waitForReady(
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("Error making request: %s\n", err.Error())
|
fmt.Printf("Error making request: %s\n", err.Error())
|
||||||
continue
|
} else {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if resp.StatusCode == http.StatusOK {
|
||||||
|
fmt.Println("Endpoint is ready!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if resp.StatusCode == http.StatusOK {
|
|
||||||
fmt.Println("Endpoint is ready!")
|
|
||||||
resp.Body.Close()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
resp.Body.Close()
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
|
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
|
||||||
)
|
)
|
||||||
|
@ -34,7 +35,7 @@ type DB struct {
|
||||||
writeDB *sql.DB
|
writeDB *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDB(ctx context.Context, url string) (*DB, error) {
|
func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, error) {
|
||||||
readDB, err := initDB(ctx, url)
|
readDB, err := initDB(ctx, url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to init read database connection: %w", err)
|
return nil, fmt.Errorf("failed to init read database connection: %w", err)
|
||||||
|
@ -81,6 +82,20 @@ func NewDB(ctx context.Context, url string) (*DB, error) {
|
||||||
return nil, fmt.Errorf("failed to migrate: %w", err)
|
return nil, fmt.Errorf("failed to migrate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
|
||||||
|
if dataEncryptionKey == "" {
|
||||||
|
return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
|
||||||
|
}
|
||||||
|
if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err)
|
||||||
|
}
|
||||||
|
versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT")
|
||||||
|
if versionsHistoryLimit != "" {
|
||||||
|
if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &db, nil
|
return &db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,14 +115,6 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||||
return db.readDB.QueryRowContext(db.ctx, query, args...)
|
return db.readDB.QueryRowContext(db.ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (db *DB) SetDataEncryptionKey(s string) error {
|
|
||||||
return db.dataEncryptionKey.FromBase64(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) SetVersionsHistoryLimit(n int) {
|
|
||||||
db.versionsHistoryLimit = n
|
|
||||||
}
|
|
||||||
|
|
||||||
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
|
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
|
||||||
tx, err := db.writeDB.Begin()
|
tx, err := db.writeDB.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Add table
Reference in a new issue