summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--pkg/database/accounts.go55
-rw-r--r--pkg/database/db.go22
-rw-r--r--pkg/database/locks.go61
-rw-r--r--pkg/database/migrations.go44
-rw-r--r--pkg/database/states.go87
5 files changed, 119 insertions, 150 deletions
diff --git a/pkg/database/accounts.go b/pkg/database/accounts.go
index f506155..e1c7109 100644
--- a/pkg/database/accounts.go
+++ b/pkg/database/accounts.go
@@ -50,42 +50,31 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) {
}
func (db *DB) InitAdminAccount() error {
- tx, err := db.Begin()
- if err != nil {
- return err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
- var hasAdminAccount bool
- if err = tx.QueryRowContext(db.ctx, `SELECT EXISTS (SELECT 1 FROM accounts WHERE is_admin);`).Scan(&hasAdminAccount); err != nil {
- return fmt.Errorf("failed to select if there is an admin account in the database: %w", err)
- }
- if hasAdminAccount {
- tx.Rollback()
- } else {
- var password uuid.UUID
- if err = password.Generate(uuid.V4); err != nil {
- return fmt.Errorf("failed to generate initial admin password: %w", err)
+ return db.WithTransaction(func(tx *sql.Tx) error {
+ var hasAdminAccount bool
+ if err := tx.QueryRowContext(db.ctx, `SELECT EXISTS (SELECT 1 FROM accounts WHERE is_admin);`).Scan(&hasAdminAccount); err != nil {
+ return fmt.Errorf("failed to select if there is an admin account in the database: %w", err)
}
- salt := helpers.GenerateSalt()
- hash := helpers.HashPassword(password.String(), salt)
- if _, err = tx.ExecContext(db.ctx,
- `INSERT INTO accounts(username, salt, password_hash, is_admin)
+ if !hasAdminAccount {
+ var password uuid.UUID
+ if err := password.Generate(uuid.V4); err != nil {
+ return fmt.Errorf("failed to generate initial admin password: %w", err)
+ }
+ salt := helpers.GenerateSalt()
+ hash := helpers.HashPassword(password.String(), salt)
+ if _, err := tx.ExecContext(db.ctx,
+ `INSERT INTO accounts(username, salt, password_hash, is_admin)
VALUES ("admin", :salt, :hash, TRUE)
ON CONFLICT DO UPDATE SET password_hash = :hash
WHERE username = "admin";`,
- sql.Named("salt", salt),
- sql.Named("hash", hash),
- ); err != nil {
- return fmt.Errorf("failed to set initial admin password: %w", err)
- }
- err = tx.Commit()
- if err == nil {
- AdvertiseAdminPassword(password.String())
+ sql.Named("salt", salt),
+ sql.Named("hash", hash),
+ ); err == nil {
+ AdvertiseAdminPassword(password.String())
+ } else {
+ return fmt.Errorf("failed to set initial admin password: %w", err)
+ }
}
- }
- return err
+ return nil
+ })
}
diff --git a/pkg/database/db.go b/pkg/database/db.go
index 38265ba..5501a82 100644
--- a/pkg/database/db.go
+++ b/pkg/database/db.go
@@ -3,6 +3,7 @@ package database
import (
"context"
"database/sql"
+ "fmt"
"runtime"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
@@ -81,10 +82,6 @@ func NewDB(ctx context.Context, url string) (*DB, error) {
return &db, nil
}
-func (db *DB) Begin() (*sql.Tx, error) {
- return db.writeDB.Begin()
-}
-
func (db *DB) Close() error {
if err := db.readDB.Close(); err != nil {
_ = db.writeDB.Close()
@@ -107,3 +104,20 @@ func (db *DB) SetDataEncryptionKey(s string) error {
func (db *DB) SetVersionsHistoryLimit(n int) {
db.versionsHistoryLimit = n
}
+
+func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
+ tx, err := db.writeDB.Begin()
+ if err != nil {
+ return err
+ }
+ err = f(tx)
+ if err == nil {
+ err = tx.Commit()
+ }
+ if err != nil {
+ if err2 := tx.Rollback(); err2 != nil {
+ panic(fmt.Sprintf("failed to rollback transaction: %+v. Reason for rollback: %+v", err2, err))
+ }
+ }
+ return err
+}
diff --git a/pkg/database/locks.go b/pkg/database/locks.go
index 6951337..e78fd0c 100644
--- a/pkg/database/locks.go
+++ b/pkg/database/locks.go
@@ -10,45 +10,32 @@ import (
// true if the function locked the state, otherwise returns false and the lock
// parameter is updated to the value of the existing lock
func (db *DB) SetLockOrGetExistingLock(path string, lock any) (bool, error) {
- tx, err := db.Begin()
- if err != nil {
- return false, err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
- var lockData []byte
- if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- if lockData, err = json.Marshal(lock); err != nil {
- return false, err
+ ret := false
+ return ret, db.WithTransaction(func(tx *sql.Tx) error {
+ var lockData []byte
+ if err := tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ if lockData, err = json.Marshal(lock); err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData)
+ ret = true
+ return err
+ } else {
+ return err
}
- _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData)
- if err != nil {
- return false, err
- }
- err = tx.Commit()
- return true, err
- } else {
- return false, err
}
- }
- if lockData != nil {
- _ = tx.Rollback()
- err = json.Unmarshal(lockData, lock)
- return false, err
- }
- if lockData, err = json.Marshal(lock); err != nil {
- return false, err
- }
- _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path)
- if err != nil {
- return false, err
- }
- err = tx.Commit()
- return true, err
+ if lockData != nil {
+ return json.Unmarshal(lockData, lock)
+ }
+ var err error
+ if lockData, err = json.Marshal(lock); err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path)
+ ret = true
+ return err
+ })
}
func (db *DB) Unlock(path, lock any) (bool, error) {
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
index b460884..31cc3d3 100644
--- a/pkg/database/migrations.go
+++ b/pkg/database/migrations.go
@@ -1,6 +1,7 @@
package database
import (
+ "database/sql"
"embed"
"io/fs"
@@ -28,36 +29,23 @@ func (db *DB) migrate() error {
return err
}
- tx, err := db.Begin()
- if err != nil {
- return err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
-
- var version int
- if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
- if err.Error() == "no such table: schema_version" {
- version = 0
- } else {
- return err
+ return db.WithTransaction(func(tx *sql.Tx) error {
+ var version int
+ if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
+ if err.Error() == "no such table: schema_version" {
+ version = 0
+ } else {
+ return err
+ }
}
- }
- for version < len(statements) {
- if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
- return err
+ for version < len(statements) {
+ if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
+ return err
+ }
+ version++
}
- version++
- }
- if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
+ _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version)
return err
- }
- if err = tx.Commit(); err != nil {
- return err
- }
- return nil
+ })
}
diff --git a/pkg/database/states.go b/pkg/database/states.go
index 4f0ce58..74839da 100644
--- a/pkg/database/states.go
+++ b/pkg/database/states.go
@@ -48,52 +48,46 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) (
if err != nil {
return false, err
}
- tx, err := db.Begin()
- if err != nil {
- return false, err
- }
- defer func() {
- if err != nil {
- _ = tx.Rollback()
- }
- }()
- var (
- stateID int64
- lockData []byte
- )
- if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- var result sql.Result
- result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path)
- if err != nil {
- return false, err
- }
- stateID, err = result.LastInsertId()
- if err != nil {
- return false, err
+ ret := false
+ return ret, db.WithTransaction(func(tx *sql.Tx) error {
+ var (
+ stateID int64
+ lockData []byte
+ )
+ if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ var result sql.Result
+ result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path)
+ if err != nil {
+ return err
+ }
+ stateID, err = result.LastInsertId()
+ if err != nil {
+ return err
+ }
+ } else {
+ return err
}
- } else {
- return false, err
}
- }
- if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 {
- err = fmt.Errorf("failed to update state, lock ID does not match")
- return true, err
- }
- _, err = tx.ExecContext(db.ctx,
- `INSERT INTO versions(account_id, state_id, data, lock)
+ if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 {
+ err = fmt.Errorf("failed to update state, lock ID does not match")
+ ret = true
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx,
+ `INSERT INTO versions(account_id, state_id, data, lock)
SELECT :accountID, :stateID, :data, lock
FROM states
WHERE states.id = :stateID;`,
- sql.Named("accountID", accountID),
- sql.Named("stateID", stateID),
- sql.Named("data", encryptedData))
- if err != nil {
- return false, err
- }
- _, err = tx.ExecContext(db.ctx,
- `DELETE FROM versions
+ sql.Named("accountID", accountID),
+ sql.Named("stateID", stateID),
+ sql.Named("data", encryptedData))
+ if err != nil {
+ return err
+ }
+ _, err = tx.ExecContext(db.ctx,
+ `DELETE FROM versions
WHERE state_id = (SELECT id
FROM states
WHERE path = :path)
@@ -104,12 +98,9 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) (
WHERE states.path = :path
ORDER BY versions.id DESC
LIMIT :limit));`,
- sql.Named("limit", db.versionsHistoryLimit),
- sql.Named("path", path),
- )
- if err != nil {
- return false, err
- }
- err = tx.Commit()
- return false, err
+ sql.Named("limit", db.versionsHistoryLimit),
+ sql.Named("path", path),
+ )
+ return err
+ })
}