From c1ddc47c95e04c242263ee9c5074cac9a2422ca2 Mon Sep 17 00:00:00 2001 From: Julien Dessaux Date: Tue, 17 Dec 2024 23:18:34 +0100 Subject: chore(tfstated): refactor database environment handling --- pkg/database/db.go | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) (limited to 'pkg/database/db.go') diff --git a/pkg/database/db.go b/pkg/database/db.go index b7f0e89..9aeec43 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -5,6 +5,7 @@ import ( "database/sql" "fmt" "runtime" + "strconv" "git.adyxax.org/adyxax/tfstated/pkg/scrypto" ) @@ -34,7 +35,7 @@ type DB struct { writeDB *sql.DB } -func NewDB(ctx context.Context, url string) (*DB, error) { +func NewDB(ctx context.Context, url string, getenv func(string) string) (*DB, error) { readDB, err := initDB(ctx, url) if err != nil { return nil, fmt.Errorf("failed to init read database connection: %w", err) @@ -81,6 +82,20 @@ func NewDB(ctx context.Context, url string) (*DB, error) { return nil, fmt.Errorf("failed to migrate: %w", err) } + dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY") + if dataEncryptionKey == "" { + return nil, fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set") + } + if err := db.dataEncryptionKey.FromBase64(dataEncryptionKey); err != nil { + return nil, fmt.Errorf("failed to decode the DATA_ENCRYPTION_KEY environment variable, expected base64: %w", err) + } + versionsHistoryLimit := getenv("VERSIONS_HISTORY_LIMIT") + if versionsHistoryLimit != "" { + if db.versionsHistoryLimit, err = strconv.Atoi(versionsHistoryLimit); err != nil { + return nil, fmt.Errorf("failed to parse the VERSIONS_HISTORY_LIMIT environment variable, expected an integer: %w", err) + } + } + return &db, nil } @@ -100,14 +115,6 @@ func (db *DB) QueryRow(query string, args ...any) *sql.Row { return db.readDB.QueryRowContext(db.ctx, query, args...) } -func (db *DB) SetDataEncryptionKey(s string) error { - return db.dataEncryptionKey.FromBase64(s) -} - -func (db *DB) SetVersionsHistoryLimit(n int) { - db.versionsHistoryLimit = n -} - func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error { tx, err := db.writeDB.Begin() if err != nil { -- cgit v1.2.3