diff options
Diffstat (limited to 'pkg/database/db.go')
-rw-r--r-- | pkg/database/db.go | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/pkg/database/db.go b/pkg/database/db.go index 38265ba..5501a82 100644 --- a/pkg/database/db.go +++ b/pkg/database/db.go @@ -3,6 +3,7 @@ package database import ( "context" "database/sql" + "fmt" "runtime" "git.adyxax.org/adyxax/tfstated/pkg/scrypto" @@ -81,10 +82,6 @@ func NewDB(ctx context.Context, url string) (*DB, error) { return &db, nil } -func (db *DB) Begin() (*sql.Tx, error) { - return db.writeDB.Begin() -} - func (db *DB) Close() error { if err := db.readDB.Close(); err != nil { _ = db.writeDB.Close() @@ -107,3 +104,20 @@ func (db *DB) SetDataEncryptionKey(s string) error { func (db *DB) SetVersionsHistoryLimit(n int) { db.versionsHistoryLimit = n } + +func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error { + tx, err := db.writeDB.Begin() + if err != nil { + return err + } + err = f(tx) + if err == nil { + err = tx.Commit() + } + if err != nil { + if err2 := tx.Rollback(); err2 != nil { + panic(fmt.Sprintf("failed to rollback transaction: %+v. Reason for rollback: %+v", err2, err)) + } + } + return err +} |