From f649f7bbbf978db0ad50b0c6d1432264cfec3f43 Mon Sep 17 00:00:00 2001
From: Julien Dessaux
Date: Mon, 18 Nov 2024 00:10:58 +0100
Subject: chore(tfstated): implement a transaction wrapper

---
 pkg/database/db.go | 22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

(limited to 'pkg/database/db.go')

diff --git a/pkg/database/db.go b/pkg/database/db.go
index 38265ba..5501a82 100644
--- a/pkg/database/db.go
+++ b/pkg/database/db.go
@@ -3,6 +3,7 @@ package database
 import (
 	"context"
 	"database/sql"
+	"fmt"
 	"runtime"
 
 	"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
@@ -81,10 +82,6 @@ func NewDB(ctx context.Context, url string) (*DB, error) {
 	return &db, nil
 }
 
-func (db *DB) Begin() (*sql.Tx, error) {
-	return db.writeDB.Begin()
-}
-
 func (db *DB) Close() error {
 	if err := db.readDB.Close(); err != nil {
 		_ = db.writeDB.Close()
@@ -107,3 +104,20 @@ func (db *DB) SetDataEncryptionKey(s string) error {
 func (db *DB) SetVersionsHistoryLimit(n int) {
 	db.versionsHistoryLimit = n
 }
+
+func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
+	tx, err := db.writeDB.Begin()
+	if err != nil {
+		return err
+	}
+	err = f(tx)
+	if err == nil {
+		err = tx.Commit()
+	}
+	if err != nil {
+		if err2 := tx.Rollback(); err2 != nil {
+			panic(fmt.Sprintf("failed to rollback transaction: %+v. Reason for rollback: %+v", err2, err))
+		}
+	}
+	return err
+}
-- 
cgit v1.2.3