chore(tfstated): implement a transaction wrapper

This commit is contained in:
Julien Dessaux 2024-11-18 00:10:58 +01:00
parent 25ed1188ed
commit f649f7bbbf
Signed by: adyxax
GPG key ID: F92E51B86E07177E
5 changed files with 119 additions and 150 deletions

View file

@ -50,42 +50,31 @@ func (db *DB) LoadAccountByUsername(username string) (*model.Account, error) {
} }
func (db *DB) InitAdminAccount() error { func (db *DB) InitAdminAccount() error {
tx, err := db.Begin() return db.WithTransaction(func(tx *sql.Tx) error {
if err != nil { var hasAdminAccount bool
return err if err := tx.QueryRowContext(db.ctx, `SELECT EXISTS (SELECT 1 FROM accounts WHERE is_admin);`).Scan(&hasAdminAccount); err != nil {
} return fmt.Errorf("failed to select if there is an admin account in the database: %w", err)
defer func() {
if err != nil {
_ = tx.Rollback()
} }
}() if !hasAdminAccount {
var hasAdminAccount bool var password uuid.UUID
if err = tx.QueryRowContext(db.ctx, `SELECT EXISTS (SELECT 1 FROM accounts WHERE is_admin);`).Scan(&hasAdminAccount); err != nil { if err := password.Generate(uuid.V4); err != nil {
return fmt.Errorf("failed to select if there is an admin account in the database: %w", err) return fmt.Errorf("failed to generate initial admin password: %w", err)
} }
if hasAdminAccount { salt := helpers.GenerateSalt()
tx.Rollback() hash := helpers.HashPassword(password.String(), salt)
} else { if _, err := tx.ExecContext(db.ctx,
var password uuid.UUID `INSERT INTO accounts(username, salt, password_hash, is_admin)
if err = password.Generate(uuid.V4); err != nil {
return fmt.Errorf("failed to generate initial admin password: %w", err)
}
salt := helpers.GenerateSalt()
hash := helpers.HashPassword(password.String(), salt)
if _, err = tx.ExecContext(db.ctx,
`INSERT INTO accounts(username, salt, password_hash, is_admin)
VALUES ("admin", :salt, :hash, TRUE) VALUES ("admin", :salt, :hash, TRUE)
ON CONFLICT DO UPDATE SET password_hash = :hash ON CONFLICT DO UPDATE SET password_hash = :hash
WHERE username = "admin";`, WHERE username = "admin";`,
sql.Named("salt", salt), sql.Named("salt", salt),
sql.Named("hash", hash), sql.Named("hash", hash),
); err != nil { ); err == nil {
return fmt.Errorf("failed to set initial admin password: %w", err) AdvertiseAdminPassword(password.String())
} else {
return fmt.Errorf("failed to set initial admin password: %w", err)
}
} }
err = tx.Commit() return nil
if err == nil { })
AdvertiseAdminPassword(password.String())
}
}
return err
} }

View file

@ -3,6 +3,7 @@ package database
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"runtime" "runtime"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto" "git.adyxax.org/adyxax/tfstated/pkg/scrypto"
@ -81,10 +82,6 @@ func NewDB(ctx context.Context, url string) (*DB, error) {
return &db, nil return &db, nil
} }
func (db *DB) Begin() (*sql.Tx, error) {
return db.writeDB.Begin()
}
func (db *DB) Close() error { func (db *DB) Close() error {
if err := db.readDB.Close(); err != nil { if err := db.readDB.Close(); err != nil {
_ = db.writeDB.Close() _ = db.writeDB.Close()
@ -107,3 +104,20 @@ func (db *DB) SetDataEncryptionKey(s string) error {
func (db *DB) SetVersionsHistoryLimit(n int) { func (db *DB) SetVersionsHistoryLimit(n int) {
db.versionsHistoryLimit = n db.versionsHistoryLimit = n
} }
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
tx, err := db.writeDB.Begin()
if err != nil {
return err
}
err = f(tx)
if err == nil {
err = tx.Commit()
}
if err != nil {
if err2 := tx.Rollback(); err2 != nil {
panic(fmt.Sprintf("failed to rollback transaction: %+v. Reason for rollback: %+v", err2, err))
}
}
return err
}

View file

@ -10,45 +10,32 @@ import (
// true if the function locked the state, otherwise returns false and the lock // true if the function locked the state, otherwise returns false and the lock
// parameter is updated to the value of the existing lock // parameter is updated to the value of the existing lock
func (db *DB) SetLockOrGetExistingLock(path string, lock any) (bool, error) { func (db *DB) SetLockOrGetExistingLock(path string, lock any) (bool, error) {
tx, err := db.Begin() ret := false
if err != nil { return ret, db.WithTransaction(func(tx *sql.Tx) error {
return false, err var lockData []byte
} if err := tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil {
defer func() { if errors.Is(err, sql.ErrNoRows) {
if err != nil { if lockData, err = json.Marshal(lock); err != nil {
_ = tx.Rollback() return err
} }
}() _, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData)
var lockData []byte ret = true
if err = tx.QueryRowContext(db.ctx, `SELECT lock FROM states WHERE path = ?;`, path).Scan(&lockData); err != nil { return err
if errors.Is(err, sql.ErrNoRows) { } else {
if lockData, err = json.Marshal(lock); err != nil { return err
return false, err
} }
_, err = tx.ExecContext(db.ctx, `INSERT INTO states(path, lock) VALUES (?, json(?))`, path, lockData)
if err != nil {
return false, err
}
err = tx.Commit()
return true, err
} else {
return false, err
} }
} if lockData != nil {
if lockData != nil { return json.Unmarshal(lockData, lock)
_ = tx.Rollback() }
err = json.Unmarshal(lockData, lock) var err error
return false, err if lockData, err = json.Marshal(lock); err != nil {
} return err
if lockData, err = json.Marshal(lock); err != nil { }
return false, err _, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path)
} ret = true
_, err = tx.ExecContext(db.ctx, `UPDATE states SET lock = json(?) WHERE path = ?;`, lockData, path) return err
if err != nil { })
return false, err
}
err = tx.Commit()
return true, err
} }
func (db *DB) Unlock(path, lock any) (bool, error) { func (db *DB) Unlock(path, lock any) (bool, error) {

View file

@ -1,6 +1,7 @@
package database package database
import ( import (
"database/sql"
"embed" "embed"
"io/fs" "io/fs"
@ -28,36 +29,23 @@ func (db *DB) migrate() error {
return err return err
} }
tx, err := db.Begin() return db.WithTransaction(func(tx *sql.Tx) error {
if err != nil { var version int
return err if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
} if err.Error() == "no such table: schema_version" {
defer func() { version = 0
if err != nil { } else {
_ = tx.Rollback() return err
}
} }
}()
var version int for version < len(statements) {
if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
if err.Error() == "no such table: schema_version" { return err
version = 0 }
} else { version++
return err
} }
} _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version)
for version < len(statements) {
if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
return err
}
version++
}
if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
return err return err
} })
if err = tx.Commit(); err != nil {
return err
}
return nil
} }

View file

@ -48,52 +48,46 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) (
if err != nil { if err != nil {
return false, err return false, err
} }
tx, err := db.Begin() ret := false
if err != nil { return ret, db.WithTransaction(func(tx *sql.Tx) error {
return false, err var (
} stateID int64
defer func() { lockData []byte
if err != nil { )
_ = tx.Rollback() if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil {
} if errors.Is(err, sql.ErrNoRows) {
}() var result sql.Result
var ( result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path)
stateID int64 if err != nil {
lockData []byte return err
) }
if err = tx.QueryRowContext(db.ctx, `SELECT id, lock->>'ID' FROM states WHERE path = ?;`, path).Scan(&stateID, &lockData); err != nil { stateID, err = result.LastInsertId()
if errors.Is(err, sql.ErrNoRows) { if err != nil {
var result sql.Result return err
result, err = tx.ExecContext(db.ctx, `INSERT INTO states(path) VALUES (?)`, path) }
if err != nil { } else {
return false, err return err
} }
stateID, err = result.LastInsertId()
if err != nil {
return false, err
}
} else {
return false, err
} }
}
if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 { if lockID != "" && slices.Compare([]byte(lockID), lockData) != 0 {
err = fmt.Errorf("failed to update state, lock ID does not match") err = fmt.Errorf("failed to update state, lock ID does not match")
return true, err ret = true
} return err
_, err = tx.ExecContext(db.ctx, }
`INSERT INTO versions(account_id, state_id, data, lock) _, err = tx.ExecContext(db.ctx,
`INSERT INTO versions(account_id, state_id, data, lock)
SELECT :accountID, :stateID, :data, lock SELECT :accountID, :stateID, :data, lock
FROM states FROM states
WHERE states.id = :stateID;`, WHERE states.id = :stateID;`,
sql.Named("accountID", accountID), sql.Named("accountID", accountID),
sql.Named("stateID", stateID), sql.Named("stateID", stateID),
sql.Named("data", encryptedData)) sql.Named("data", encryptedData))
if err != nil { if err != nil {
return false, err return err
} }
_, err = tx.ExecContext(db.ctx, _, err = tx.ExecContext(db.ctx,
`DELETE FROM versions `DELETE FROM versions
WHERE state_id = (SELECT id WHERE state_id = (SELECT id
FROM states FROM states
WHERE path = :path) WHERE path = :path)
@ -104,12 +98,9 @@ func (db *DB) SetState(path string, accountID int, data []byte, lockID string) (
WHERE states.path = :path WHERE states.path = :path
ORDER BY versions.id DESC ORDER BY versions.id DESC
LIMIT :limit));`, LIMIT :limit));`,
sql.Named("limit", db.versionsHistoryLimit), sql.Named("limit", db.versionsHistoryLimit),
sql.Named("path", path), sql.Named("path", path),
) )
if err != nil { return err
return false, err })
}
err = tx.Commit()
return false, err
} }