diff options
Diffstat (limited to '')
-rw-r--r-- | pkg/database/db.go | 103 | ||||
-rw-r--r-- | pkg/database/migrations.go | 63 | ||||
-rw-r--r-- | pkg/database/sql/000_init.sql | 10 | ||||
-rw-r--r-- | pkg/database/states.go | 30 |
4 files changed, 206 insertions, 0 deletions
diff --git a/pkg/database/db.go b/pkg/database/db.go new file mode 100644 index 0000000..2744570 --- /dev/null +++ b/pkg/database/db.go @@ -0,0 +1,103 @@ +package database + +import ( + "context" + "database/sql" + "runtime" + + "git.adyxax.org/adyxax/tfstated/pkg/scrypto" +) + +func initDB(ctx context.Context, url string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", url) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = db.Close() + } + }() + if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil { + return nil, err + } + + return db, nil +} + +type DB struct { + ctx context.Context + dataEncryptionKey scrypto.AES256Key + readDB *sql.DB + writeDB *sql.DB +} + +func NewDB(ctx context.Context, url string) (*DB, error) { + readDB, err := initDB(ctx, url) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = readDB.Close() + } + }() + readDB.SetMaxOpenConns(max(4, runtime.NumCPU())) + + writeDB, err := initDB(ctx, url) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = writeDB.Close() + } + }() + writeDB.SetMaxOpenConns(1) + + db := DB{ + ctx: ctx, + readDB: readDB, + writeDB: writeDB, + } + if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, err + } + if _, err = db.Exec("PRAGMA cache_size = 10000000"); err != nil { + return nil, err + } + if _, err = db.Exec("PRAGMA journal_mode = WAL"); err != nil { + return nil, err + } + if _, err = db.Exec("PRAGMA synchronous = NORMAL"); err != nil { + return nil, err + } + if err = db.migrate(); err != nil { + return nil, err + } + + return &db, nil +} + +func (db *DB) Begin() (*sql.Tx, error) { + return db.writeDB.Begin() +} + +func (db *DB) Close() error { + if err := db.readDB.Close(); err != nil { + _ = db.writeDB.Close() + } + return db.writeDB.Close() +} + +func (db *DB) Exec(query string, args ...any) (sql.Result, error) { + return db.writeDB.ExecContext(db.ctx, query, args...) +} + +func (db *DB) QueryRow(query string, args ...any) *sql.Row { + return db.readDB.QueryRowContext(db.ctx, query, args...) +} + +func (db *DB) SetDataEncryptionKey(s string) error { + return db.dataEncryptionKey.FromBase64(s) +} diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go new file mode 100644 index 0000000..b460884 --- /dev/null +++ b/pkg/database/migrations.go @@ -0,0 +1,63 @@ +package database + +import ( + "embed" + "io/fs" + + _ "github.com/mattn/go-sqlite3" +) + +//go:embed sql/*.sql +var schemaFiles embed.FS + +func (db *DB) migrate() error { + statements := make([]string, 0) + err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error { + if d.IsDir() || err != nil { + return err + } + var stmts []byte + if stmts, err = schemaFiles.ReadFile(path); err != nil { + return err + } else { + statements = append(statements, string(stmts)) + } + return nil + }) + if err != nil { + return err + } + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + var version int + if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { + if err.Error() == "no such table: schema_version" { + version = 0 + } else { + return err + } + } + + for version < len(statements) { + if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { + return err + } + version++ + } + if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil { + return err + } + if err = tx.Commit(); err != nil { + return err + } + return nil +} diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql new file mode 100644 index 0000000..0248896 --- /dev/null +++ b/pkg/database/sql/000_init.sql @@ -0,0 +1,10 @@ +CREATE TABLE schema_version ( + version INTEGER NOT NULL +) STRICT; + +CREATE TABLE states ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + data BLOB NOT NULL +) STRICT; +CREATE UNIQUE INDEX states_name on states(name); diff --git a/pkg/database/states.go b/pkg/database/states.go new file mode 100644 index 0000000..ef2263b --- /dev/null +++ b/pkg/database/states.go @@ -0,0 +1,30 @@ +package database + +import ( + "database/sql" +) + +func (db *DB) GetState(name string) ([]byte, error) { + var encryptedData []byte + err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData) + if err != nil { + return nil, err + } + return db.dataEncryptionKey.DecryptAES256(encryptedData) +} + +func (db *DB) SetState(name string, data []byte) error { + encryptedData, err := db.dataEncryptionKey.EncryptAES256(data) + if err != nil { + return err + } + _, err = db.Exec( + `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`, + sql.Named("data", encryptedData), + sql.Named("name", name), + ) + if err != nil { + return err + } + return nil +} |