summaryrefslogtreecommitdiff
path: root/pkg/database
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--pkg/database/db.go103
-rw-r--r--pkg/database/migrations.go63
-rw-r--r--pkg/database/sql/000_init.sql10
-rw-r--r--pkg/database/states.go30
4 files changed, 206 insertions, 0 deletions
diff --git a/pkg/database/db.go b/pkg/database/db.go
new file mode 100644
index 0000000..2744570
--- /dev/null
+++ b/pkg/database/db.go
@@ -0,0 +1,103 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "runtime"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/scrypto"
+)
+
+func initDB(ctx context.Context, url string) (*sql.DB, error) {
+ db, err := sql.Open("sqlite3", url)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ _ = db.Close()
+ }
+ }()
+ if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil {
+ return nil, err
+ }
+
+ return db, nil
+}
+
+type DB struct {
+ ctx context.Context
+ dataEncryptionKey scrypto.AES256Key
+ readDB *sql.DB
+ writeDB *sql.DB
+}
+
+func NewDB(ctx context.Context, url string) (*DB, error) {
+ readDB, err := initDB(ctx, url)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ _ = readDB.Close()
+ }
+ }()
+ readDB.SetMaxOpenConns(max(4, runtime.NumCPU()))
+
+ writeDB, err := initDB(ctx, url)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ _ = writeDB.Close()
+ }
+ }()
+ writeDB.SetMaxOpenConns(1)
+
+ db := DB{
+ ctx: ctx,
+ readDB: readDB,
+ writeDB: writeDB,
+ }
+ if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil {
+ return nil, err
+ }
+ if _, err = db.Exec("PRAGMA cache_size = 10000000"); err != nil {
+ return nil, err
+ }
+ if _, err = db.Exec("PRAGMA journal_mode = WAL"); err != nil {
+ return nil, err
+ }
+ if _, err = db.Exec("PRAGMA synchronous = NORMAL"); err != nil {
+ return nil, err
+ }
+ if err = db.migrate(); err != nil {
+ return nil, err
+ }
+
+ return &db, nil
+}
+
+func (db *DB) Begin() (*sql.Tx, error) {
+ return db.writeDB.Begin()
+}
+
+func (db *DB) Close() error {
+ if err := db.readDB.Close(); err != nil {
+ _ = db.writeDB.Close()
+ }
+ return db.writeDB.Close()
+}
+
+func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
+ return db.writeDB.ExecContext(db.ctx, query, args...)
+}
+
+func (db *DB) QueryRow(query string, args ...any) *sql.Row {
+ return db.readDB.QueryRowContext(db.ctx, query, args...)
+}
+
+func (db *DB) SetDataEncryptionKey(s string) error {
+ return db.dataEncryptionKey.FromBase64(s)
+}
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
new file mode 100644
index 0000000..b460884
--- /dev/null
+++ b/pkg/database/migrations.go
@@ -0,0 +1,63 @@
+package database
+
+import (
+ "embed"
+ "io/fs"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+//go:embed sql/*.sql
+var schemaFiles embed.FS
+
+func (db *DB) migrate() error {
+ statements := make([]string, 0)
+ err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error {
+ if d.IsDir() || err != nil {
+ return err
+ }
+ var stmts []byte
+ if stmts, err = schemaFiles.ReadFile(path); err != nil {
+ return err
+ } else {
+ statements = append(statements, string(stmts))
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+ defer func() {
+ if err != nil {
+ _ = tx.Rollback()
+ }
+ }()
+
+ var version int
+ if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
+ if err.Error() == "no such table: schema_version" {
+ version = 0
+ } else {
+ return err
+ }
+ }
+
+ for version < len(statements) {
+ if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
+ return err
+ }
+ version++
+ }
+ if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
+ return err
+ }
+ if err = tx.Commit(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql
new file mode 100644
index 0000000..0248896
--- /dev/null
+++ b/pkg/database/sql/000_init.sql
@@ -0,0 +1,10 @@
+CREATE TABLE schema_version (
+ version INTEGER NOT NULL
+) STRICT;
+
+CREATE TABLE states (
+ id INTEGER PRIMARY KEY,
+ name TEXT NOT NULL,
+ data BLOB NOT NULL
+) STRICT;
+CREATE UNIQUE INDEX states_name on states(name);
diff --git a/pkg/database/states.go b/pkg/database/states.go
new file mode 100644
index 0000000..ef2263b
--- /dev/null
+++ b/pkg/database/states.go
@@ -0,0 +1,30 @@
+package database
+
+import (
+ "database/sql"
+)
+
+func (db *DB) GetState(name string) ([]byte, error) {
+ var encryptedData []byte
+ err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData)
+ if err != nil {
+ return nil, err
+ }
+ return db.dataEncryptionKey.DecryptAES256(encryptedData)
+}
+
+func (db *DB) SetState(name string, data []byte) error {
+ encryptedData, err := db.dataEncryptionKey.EncryptAES256(data)
+ if err != nil {
+ return err
+ }
+ _, err = db.Exec(
+ `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`,
+ sql.Named("data", encryptedData),
+ sql.Named("name", name),
+ )
+ if err != nil {
+ return err
+ }
+ return nil
+}