1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
package database
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"runtime"

	"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
)
// initDB opens a database handle on url with the "sqlite3" driver and sets
// busy_timeout to 5000 by executing a PRAGMA with ctx. If the PRAGMA fails,
// the handle is closed before the error is returned, so callers never
// receive a half-initialized handle.
//
// NOTE(review): busy_timeout is a per-connection SQLite setting, but this
// PRAGMA runs on only one connection of the returned pool. Connections the
// pool opens later may lack it unless the driver applies it itself (e.g. via
// a DSN option) — confirm against the driver in use.
func initDB(ctx context.Context, url string) (*sql.DB, error) {
	handle, err := sql.Open("sqlite3", url)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	if _, execErr := handle.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); execErr != nil {
		// Best effort: the PRAGMA failure is the error worth reporting.
		_ = handle.Close()
		return nil, fmt.Errorf("failed to set pragma: %w", execErr)
	}
	return handle, nil
}
// DB holds two *sql.DB pools that NewDB opens on the same database URL:
// readDB, used by QueryRow, and writeDB, used by Exec and WithTransaction.
//
// NOTE(review): storing a context.Context in a struct goes against the Go
// context package guidance; it is kept because Exec and QueryRow run their
// statements with it.
type DB struct {
// ctx is the context passed to NewDB; Exec and QueryRow run their statements with it.
ctx context.Context
// dataEncryptionKey is loaded from a string by SetDataEncryptionKey.
// NOTE(review): presumably used elsewhere in the package to encrypt stored
// data — confirm against its users.
dataEncryptionKey scrypto.AES256Key
// readDB is the read pool; NewDB caps it at max(4, runtime.NumCPU()) open connections.
readDB *sql.DB
// versionsHistoryLimit defaults to 64 in NewDB and is changed by
// SetVersionsHistoryLimit. NOTE(review): presumably the number of past
// versions retained — confirm against its users.
versionsHistoryLimit int
// writeDB is the write pool; NewDB caps it at one open connection,
// presumably to serialize writes (SQLite allows a single writer at a time).
writeDB *sql.DB
}
// NewDB opens two handles on the SQLite database at url through initDB:
//   - a read pool limited to max(4, runtime.NumCPU()) open connections;
//   - a write pool limited to a single open connection.
//
// It then applies the foreign_keys, cache_size, journal_mode and synchronous
// pragmas through Exec (i.e. on the write pool) and calls db.migrate. The
// returned DB has a versions history limit of 64. If any step fails, every
// handle opened so far is closed (write pool first, then read pool) and a
// wrapped error is returned.
//
// NOTE(review): apart from journal_mode, which SQLite persists in the
// database file, these pragmas are per-connection and reach only the write
// pool's connection — confirm this is intended for the read pool.
func NewDB(ctx context.Context, url string) (*DB, error) {
	readDB, err := initDB(ctx, url)
	if err != nil {
		return nil, fmt.Errorf("failed to init read database connection: %w", err)
	}
	readDB.SetMaxOpenConns(max(4, runtime.NumCPU()))

	writeDB, err := initDB(ctx, url)
	if err != nil {
		_ = readDB.Close()
		return nil, fmt.Errorf("failed to init write database connection: %w", err)
	}
	writeDB.SetMaxOpenConns(1)

	// abort releases both pools (write first) and reports the failure.
	abort := func(cause error) (*DB, error) {
		_ = writeDB.Close()
		_ = readDB.Close()
		return nil, cause
	}

	db := &DB{
		ctx:                  ctx,
		readDB:               readDB,
		versionsHistoryLimit: 64,
		writeDB:              writeDB,
	}

	statements := []string{
		"PRAGMA foreign_keys = ON",
		"PRAGMA cache_size = 10000000",
		"PRAGMA journal_mode = WAL",
		"PRAGMA synchronous = NORMAL",
	}
	for _, stmt := range statements {
		if _, execErr := db.Exec(stmt); execErr != nil {
			return abort(fmt.Errorf("failed to set pragma: %w", execErr))
		}
	}

	if migrateErr := db.migrate(); migrateErr != nil {
		return abort(fmt.Errorf("failed to migrate: %w", migrateErr))
	}
	return db, nil
}
// Close closes both the read and the write connection pools.
//
// Both pools are always closed, even when the first Close fails. Every
// failure is wrapped with the pool it came from and the failures are
// returned together via errors.Join, so errors.Is matches any of them.
// Close returns nil when both pools close cleanly.
func (db *DB) Close() error {
	var errs []error
	if err := db.readDB.Close(); err != nil {
		errs = append(errs, fmt.Errorf("failed to close read database connection: %w", err))
	}
	if err := db.writeDB.Close(); err != nil {
		errs = append(errs, fmt.Errorf("failed to close write database connection: %w", err))
	}
	return errors.Join(errs...)
}
// Exec executes query with args on the write pool, using the context the DB
// was created with, and returns the driver's result and error unchanged.
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
	result, err := db.writeDB.ExecContext(db.ctx, query, args...)
	return result, err
}
// QueryRow runs query with args on the read pool, using the context the DB
// was created with. As with sql.DB.QueryRowContext, any error is deferred
// until the returned row is scanned.
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
	row := db.readDB.QueryRowContext(db.ctx, query, args...)
	return row
}
// SetDataEncryptionKey loads the DB's data encryption key from s by calling
// scrypto.AES256Key.FromBase64 (s is presumably base64-encoded, per the
// method name). It returns that method's error unchanged.
func (db *DB) SetDataEncryptionKey(s string) error {
	key := &db.dataEncryptionKey
	return key.FromBase64(s)
}
// SetVersionsHistoryLimit replaces the versions history limit (64 by
// default) with n. The value is stored as-is, without validation.
func (db *DB) SetVersionsHistoryLimit(n int) {
	limit := &db.versionsHistoryLimit
	*limit = n
}
// WithTransaction runs f inside a transaction on the write pool. The
// transaction is started with the context the DB was created with, like Exec
// and QueryRow.
//
// Outcomes:
//   - f returns nil: the transaction is committed.
//   - f returns an error or panics, or the commit fails: the transaction is
//     rolled back. A panic from f is re-raised after the rollback.
//
// Errors from BeginTx, f and Commit are returned wrapped. A rollback failure
// (other than the expected sql.ErrTxDone) is joined to the returned error
// instead of panicking.
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) (err error) {
	tx, err := db.writeDB.BeginTx(db.ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Always attempt a rollback. Once Commit has run (successfully or not),
	// database/sql marks the Tx done and Rollback returns sql.ErrTxDone,
	// which is expected and ignored. Because this runs even when f panics,
	// the write pool's connection (capped at one by NewDB) is always released.
	defer func() {
		if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
			err = errors.Join(err, fmt.Errorf("failed to rollback transaction: %w", rbErr))
		}
	}()
	if err = f(tx); err != nil {
		return fmt.Errorf("failed to execute function inside transaction: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}
|