summaryrefslogtreecommitdiff
path: root/pkg/database/db.go
blob: 0777556e3604dab6ec4c1fb854a5c96d9382990f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package database

import (
	"context"
	"database/sql"
	"fmt"
	"runtime"

	"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
)

// initDB opens url with the driver registered as "sqlite3" and sets
// "PRAGMA busy_timeout = 5000" on it (5000 ms, per SQLite's busy_timeout
// pragma). If that statement fails, the handle is closed and the error is
// returned as is; on success the caller owns the returned *sql.DB.
//
// NOTE(review): *sql.DB is a connection pool and ExecContext runs the PRAGMA
// on only one of its connections. Assuming busy_timeout is a per-connection
// setting (as in SQLite), connections the pool opens later will not have it.
// Setting it through the DSN instead is driver-specific — confirm which
// driver is registered under "sqlite3" before changing this.
func initDB(ctx context.Context, url string) (*sql.DB, error) {
	handle, err := sql.Open("sqlite3", url)
	if err != nil {
		return nil, err
	}
	if _, err := handle.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil {
		// Best effort: the Exec error is the one worth reporting.
		_ = handle.Close()
		return nil, err
	}
	return handle, nil
}

// DB gives access to the database through two *sql.DB pools that NewDB opens
// on the same url:
//   - readDB serves QueryRow; NewDB caps it at max(4, runtime.NumCPU()) open
//     connections.
//   - writeDB serves Exec and WithTransaction; NewDB caps it at a single open
//     connection, so writes go through one connection at a time.
//
// ctx is the context passed to NewDB; Exec and QueryRow run their queries
// under it.
//
// dataEncryptionKey is set from a string by SetDataEncryptionKey.
// NOTE(review): presumably used elsewhere in the package to encrypt stored
// data — confirm.
//
// versionsHistoryLimit is 64 as set by NewDB unless changed through
// SetVersionsHistoryLimit. NOTE(review): presumably caps how many past
// versions are kept — confirm against its users.
type DB struct {
	ctx                  context.Context
	dataEncryptionKey    scrypto.AES256Key
	readDB               *sql.DB
	versionsHistoryLimit int
	writeDB              *sql.DB
}

// NewDB opens the database at url twice, once as a read pool and once as a
// write pool. It then configures both, runs the schema migrations and returns
// the resulting DB.
//
// The read pool may hold up to max(4, runtime.NumCPU()) open connections. The
// write pool holds at most one, so writes issued through it are serialized.
// ctx is stored in the DB and used by later queries.
//
// The PRAGMAs below are sent with Exec, i.e. through the write pool only.
// NOTE(review): in SQLite foreign_keys, cache_size and synchronous are
// per-connection settings, so the read pool does not receive them — confirm
// that is intended. cache_size = 10000000 looks meant as "effectively
// unbounded"; confirm.
//
// If any step fails, the pools opened so far are closed (write pool first)
// before the error is returned unchanged.
func NewDB(ctx context.Context, url string) (*DB, error) {
	var readDB, writeDB *sql.DB
	var err error
	// Undo partial initialisation on any failure path below.
	defer func() {
		if err == nil {
			return
		}
		if writeDB != nil {
			_ = writeDB.Close()
		}
		if readDB != nil {
			_ = readDB.Close()
		}
	}()

	if readDB, err = initDB(ctx, url); err != nil {
		return nil, err
	}
	readDB.SetMaxOpenConns(max(4, runtime.NumCPU()))

	if writeDB, err = initDB(ctx, url); err != nil {
		return nil, err
	}
	writeDB.SetMaxOpenConns(1)

	db := &DB{
		ctx:                  ctx,
		readDB:               readDB,
		versionsHistoryLimit: 64,
		writeDB:              writeDB,
	}
	pragmas := [...]string{
		"PRAGMA foreign_keys = ON",
		"PRAGMA cache_size = 10000000",
		"PRAGMA journal_mode = WAL",
		"PRAGMA synchronous = NORMAL",
	}
	for _, pragma := range pragmas {
		if _, err = db.Exec(pragma); err != nil {
			return nil, err
		}
	}
	if err = db.migrate(); err != nil {
		return nil, err
	}
	return db, nil
}

// Close closes both the read and the write pool.
//
// Both pools are always closed, even when the first Close fails. Any error is
// returned wrapped with %w; when both fail, both errors are included, so
// errors.Is matches either. Previously a read-pool failure was lost: the write
// pool was closed twice, and because (*sql.DB).Close is idempotent, the second
// call returned nil.
func (db *DB) Close() error {
	readErr := db.readDB.Close()
	writeErr := db.writeDB.Close()
	switch {
	case readErr != nil && writeErr != nil:
		return fmt.Errorf("closing read database: %w; closing write database: %w", readErr, writeErr)
	case readErr != nil:
		return fmt.Errorf("closing read database: %w", readErr)
	case writeErr != nil:
		return fmt.Errorf("closing write database: %w", writeErr)
	default:
		return nil
	}
}

// Exec runs query with args on the write pool, under the context stored when
// the DB was built, for statements that return no rows.
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
	w := db.writeDB
	return w.ExecContext(db.ctx, query, args...)
}

// QueryRow runs query with args on the read pool, under the context stored
// when the DB was built, and returns at most one row. Any error is deferred
// until the row is scanned.
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
	r := db.readDB
	return r.QueryRowContext(db.ctx, query, args...)
}

// SetDataEncryptionKey loads the data encryption key from s by delegating to
// the key's FromBase64 method. It returns that method's error unchanged;
// presumably s must be base64 text, as the method name suggests.
func (db *DB) SetDataEncryptionKey(s string) error {
	key := &db.dataEncryptionKey
	return key.FromBase64(s)
}

// SetVersionsHistoryLimit replaces the versions history limit (64 after
// NewDB) with n. The value is stored as given, without validation.
func (db *DB) SetVersionsHistoryLimit(n int) {
	limit := n
	db.versionsHistoryLimit = limit
}

// WithTransaction runs f inside a transaction on the write pool and commits
// it when f returns nil.
//
// The transaction is begun under the context stored in the DB, like Exec and
// QueryRow. If f returns an error or panics, the transaction is rolled back
// so the write connection is released; a panic then continues unwinding
// unchanged. All errors are wrapped with %w, so callers can still match the
// underlying cause with errors.Is or errors.As.
//
// A rollback failure after f's error is reported together with that error
// instead of causing a panic. A failed Commit is returned as an error.
// Previously a failed Commit made the deferred Rollback return sql.ErrTxDone,
// which triggered a panic.
func (db *DB) WithTransaction(f func(tx *sql.Tx) error) error {
	tx, err := db.writeDB.BeginTx(db.ctx, nil)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Safety net for every exit path, including a panic in f. Once Commit or
	// the explicit Rollback below has run, database/sql treats the transaction
	// as finished and this call only returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	if err = f(tx); err != nil {
		// ErrTxDone means database/sql already rolled back (e.g. db.ctx was
		// cancelled while f ran), so there is nothing more to report.
		if rbErr := tx.Rollback(); rbErr != nil && rbErr != sql.ErrTxDone {
			return fmt.Errorf("failed to execute function inside transaction: %w (rollback also failed: %w)", err, rbErr)
		}
		return fmt.Errorf("failed to execute function inside transaction: %w", err)
	}
	// database/sql considers the transaction finished after Commit even when
	// Commit fails, so the error is simply returned.
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit transaction: %w", err)
	}
	return nil
}