diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/database/db.go | 103 | ||||
-rw-r--r-- | pkg/database/migrations.go | 63 | ||||
-rw-r--r-- | pkg/database/sql/000_init.sql | 10 | ||||
-rw-r--r-- | pkg/database/states.go | 30 | ||||
-rw-r--r-- | pkg/scrypto/LICENSE | 13 | ||||
-rw-r--r-- | pkg/scrypto/README.md | 5 | ||||
-rw-r--r-- | pkg/scrypto/aes256.go | 138 | ||||
-rw-r--r-- | pkg/scrypto/aes256_test.go | 60 | ||||
-rw-r--r-- | pkg/scrypto/pkcs5.go | 31 | ||||
-rw-r--r-- | pkg/scrypto/pkcs5_test.go | 52 | ||||
-rw-r--r-- | pkg/scrypto/random.go | 16 |
11 files changed, 521 insertions, 0 deletions
diff --git a/pkg/database/db.go b/pkg/database/db.go new file mode 100644 index 0000000..2744570 --- /dev/null +++ b/pkg/database/db.go @@ -0,0 +1,103 @@ +package database + +import ( + "context" + "database/sql" + "runtime" + + "git.adyxax.org/adyxax/tfstated/pkg/scrypto" +) + +func initDB(ctx context.Context, url string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", url) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = db.Close() + } + }() + if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil { + return nil, err + } + + return db, nil +} + +type DB struct { + ctx context.Context + dataEncryptionKey scrypto.AES256Key + readDB *sql.DB + writeDB *sql.DB +} + +func NewDB(ctx context.Context, url string) (*DB, error) { + readDB, err := initDB(ctx, url) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = readDB.Close() + } + }() + readDB.SetMaxOpenConns(max(4, runtime.NumCPU())) + + writeDB, err := initDB(ctx, url) + if err != nil { + return nil, err + } + defer func() { + if err != nil { + _ = writeDB.Close() + } + }() + writeDB.SetMaxOpenConns(1) + + db := DB{ + ctx: ctx, + readDB: readDB, + writeDB: writeDB, + } + if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, err + } + if _, err = db.Exec("PRAGMA cache_size = 10000000"); err != nil { + return nil, err + } + if _, err = db.Exec("PRAGMA journal_mode = WAL"); err != nil { + return nil, err + } + if _, err = db.Exec("PRAGMA synchronous = NORMAL"); err != nil { + return nil, err + } + if err = db.migrate(); err != nil { + return nil, err + } + + return &db, nil +} + +func (db *DB) Begin() (*sql.Tx, error) { + return db.writeDB.Begin() +} + +func (db *DB) Close() error { + if err := db.readDB.Close(); err != nil { + _ = db.writeDB.Close() + } + return db.writeDB.Close() +} + +func (db *DB) Exec(query string, args ...any) (sql.Result, error) { + return db.writeDB.ExecContext(db.ctx, query, args...) +} + +func (db *DB) QueryRow(query string, args ...any) *sql.Row { + return db.readDB.QueryRowContext(db.ctx, query, args...) +} + +func (db *DB) SetDataEncryptionKey(s string) error { + return db.dataEncryptionKey.FromBase64(s) +} diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go new file mode 100644 index 0000000..b460884 --- /dev/null +++ b/pkg/database/migrations.go @@ -0,0 +1,63 @@ +package database + +import ( + "embed" + "io/fs" + + _ "github.com/mattn/go-sqlite3" +) + +//go:embed sql/*.sql +var schemaFiles embed.FS + +func (db *DB) migrate() error { + statements := make([]string, 0) + err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error { + if d.IsDir() || err != nil { + return err + } + var stmts []byte + if stmts, err = schemaFiles.ReadFile(path); err != nil { + return err + } else { + statements = append(statements, string(stmts)) + } + return nil + }) + if err != nil { + return err + } + + tx, err := db.Begin() + if err != nil { + return err + } + defer func() { + if err != nil { + _ = tx.Rollback() + } + }() + + var version int + if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil { + if err.Error() == "no such table: schema_version" { + version = 0 + } else { + return err + } + } + + for version < len(statements) { + if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil { + return err + } + version++ + } + if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil { + return err + } + if err = tx.Commit(); err != nil { + return err + } + return nil +} diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql new file mode 100644 index 0000000..0248896 --- /dev/null +++ b/pkg/database/sql/000_init.sql @@ -0,0 +1,10 @@ +CREATE TABLE schema_version ( + version INTEGER NOT NULL +) STRICT; + +CREATE TABLE states ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + data BLOB NOT NULL +) STRICT; +CREATE UNIQUE INDEX states_name on states(name); diff --git a/pkg/database/states.go b/pkg/database/states.go new file mode 100644 index 0000000..ef2263b --- /dev/null +++ b/pkg/database/states.go @@ -0,0 +1,30 @@ +package database + +import ( + "database/sql" +) + +func (db *DB) GetState(name string) ([]byte, error) { + var encryptedData []byte + err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData) + if err != nil { + return nil, err + } + return db.dataEncryptionKey.DecryptAES256(encryptedData) +} + +func (db *DB) SetState(name string, data []byte) error { + encryptedData, err := db.dataEncryptionKey.EncryptAES256(data) + if err != nil { + return err + } + _, err = db.Exec( + `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`, + sql.Named("data", encryptedData), + sql.Named("name", name), + ) + if err != nil { + return err + } + return nil +} diff --git a/pkg/scrypto/LICENSE b/pkg/scrypto/LICENSE new file mode 100644 index 0000000..aeafaf4 --- /dev/null +++ b/pkg/scrypto/LICENSE @@ -0,0 +1,13 @@ +Copyright (c) 2023-2024 Nicolas Martyanoff <nicolas@n16f.net> + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted, provided that the above +copyright notice and this permission notice appear in all copies. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/pkg/scrypto/README.md b/pkg/scrypto/README.md new file mode 100644 index 0000000..e83bc8e --- /dev/null +++ b/pkg/scrypto/README.md @@ -0,0 +1,5 @@ +Original code imported from https://github.com/galdor/go-service/tree/master/pkg/scrypto + +Alterations: +- converted the EncryptAES256 and DecryptAES256 functions to methods +- rewrote the tests to get rid of the dependency on testify diff --git a/pkg/scrypto/aes256.go b/pkg/scrypto/aes256.go new file mode 100644 index 0000000..27a39d6 --- /dev/null +++ b/pkg/scrypto/aes256.go @@ -0,0 +1,138 @@ +package scrypto + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" +) + +type AES256Key [32]byte + +const ( + AES256IVSize int = aes.BlockSize +) + +var Zero = AES256Key{} + +func (key *AES256Key) Bytes() []byte { + return key[:] +} + +func (key *AES256Key) IsZero() bool { + return bytes.Equal(key.Bytes(), Zero.Bytes()) +} + +func (key *AES256Key) Hex() string { + return hex.EncodeToString(key[:]) +} + +func (key *AES256Key) FromHex(s string) error { + data, err := hex.DecodeString(s) + if err != nil { + return err + } + + if len(data) != 32 { + return fmt.Errorf("invalid key size") + } + + copy((*key)[:], data[:32]) + + return nil +} + +func (key *AES256Key) FromBase64(s string) error { + data, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return err + } + + if len(data) != 32 { + return fmt.Errorf("invalid key size") + } + + copy((*key)[:], data[:32]) + + return nil +} + +func (key *AES256Key) MarshalJSON() ([]byte, error) { + s := base64.StdEncoding.EncodeToString(key[:]) + return json.Marshal(s) +} + +func (pkey *AES256Key) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + + key, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return fmt.Errorf("cannot decode base64 value: %w", err) + } + + if len(key) != 32 { + return fmt.Errorf("invalid key size (must be 32 bytes long)") + } + + copy(pkey[:], key) + + return nil +} + +func (key *AES256Key) EncryptAES256(inputData []byte) ([]byte, error) { + blockCipher, err := aes.NewCipher(key[:]) + if err != nil { + return nil, fmt.Errorf("cannot create cipher: %w", err) + } + + paddedData := PadPKCS5(inputData, aes.BlockSize) + + outputData := make([]byte, AES256IVSize+len(paddedData)) + + iv := outputData[:AES256IVSize] + encryptedData := outputData[AES256IVSize:] + + if _, err := rand.Read(iv); err != nil { + return nil, fmt.Errorf("cannot generate iv: %w", err) + } + + encrypter := cipher.NewCBCEncrypter(blockCipher, iv) + encrypter.CryptBlocks(encryptedData, paddedData) + + return outputData, nil +} + +func (key *AES256Key) DecryptAES256(inputData []byte) ([]byte, error) { + blockCipher, err := aes.NewCipher(key[:]) + if err != nil { + return nil, fmt.Errorf("cannot create cipher: %w", err) + } + + if len(inputData) < AES256IVSize { + return nil, fmt.Errorf("truncated data") + } + + iv := inputData[:AES256IVSize] + paddedData := inputData[AES256IVSize:] + + if len(paddedData)%aes.BlockSize != 0 { + return nil, fmt.Errorf("invalid padded data length") + } + + decrypter := cipher.NewCBCDecrypter(blockCipher, iv) + decrypter.CryptBlocks(paddedData, paddedData) + + outputData, err := UnpadPKCS5(paddedData, aes.BlockSize) + if err != nil { + return nil, fmt.Errorf("invalid padded data: %w", err) + } + + return outputData, nil +} diff --git a/pkg/scrypto/aes256_test.go b/pkg/scrypto/aes256_test.go new file mode 100644 index 0000000..9914625 --- /dev/null +++ b/pkg/scrypto/aes256_test.go @@ -0,0 +1,60 @@ +package scrypto + +import ( + "encoding/hex" + "slices" + "strings" + "testing" +) + +func TestAES256KeyHex(t *testing.T) { + testKeyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284" + testKey, _ := hex.DecodeString(testKeyHex) + + var key AES256Key + if err := key.FromHex(testKeyHex); err != nil { + t.Errorf("got unexpected error %+v", err) + } + if slices.Compare(testKey, key[:]) != 0 { + t.Errorf("got %v, wanted %v", testKey, key[:]) + } + if strings.Compare(testKeyHex, key.Hex()) != 0 { + t.Errorf("got %v, wanted %v", testKeyHex, key.Hex()) + } +} + +func TestAES256(t *testing.T) { + keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284" + var key AES256Key + if err := key.FromHex(keyHex); err != nil { + t.Errorf("got unexpected error %+v", err) + } + + data := []byte("Hello world!") + encryptedData, err := key.EncryptAES256(data) + if err != nil { + t.Errorf("got unexpected error when encrypting data %+v", err) + } + + decryptedData, err := key.DecryptAES256(encryptedData) + if err != nil { + t.Errorf("got unexpected error when decrypting data %+v", err) + } + + if slices.Compare(data, decryptedData) != 0 { + t.Errorf("got %v, wanted %v", decryptedData, data) + } +} + +func TestAES256InvalidData(t *testing.T) { + keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284" + var key AES256Key + if err := key.FromHex(keyHex); err != nil { + t.Errorf("got unexpected error when converting data from base64: %+v", err) + } + + iv := make([]byte, AES256IVSize) + if _, err := key.DecryptAES256(append(iv, []byte("foo")...)); err == nil { + t.Error("decrypting operation should have failed") + } +} diff --git a/pkg/scrypto/pkcs5.go b/pkg/scrypto/pkcs5.go new file mode 100644 index 0000000..b41c564 --- /dev/null +++ b/pkg/scrypto/pkcs5.go @@ -0,0 +1,31 @@ +package scrypto + +import ( + "bytes" + "fmt" +) + +func PadPKCS5(data []byte, blockSize int) []byte { + paddingSize := blockSize - len(data)%blockSize + padding := bytes.Repeat([]byte{byte(paddingSize)}, paddingSize) + + return append(data, padding...) +} + +func UnpadPKCS5(data []byte, blockSize int) ([]byte, error) { + dataSize := len(data) + if dataSize%blockSize != 0 { + return nil, fmt.Errorf("truncated data") + } + + if len(data) == 0 { + return data, nil + } else { + paddingSize := int(data[dataSize-1]) + if paddingSize > dataSize || paddingSize > blockSize { + return nil, fmt.Errorf("invalid padding size %d", paddingSize) + } + + return data[:dataSize-paddingSize], nil + } +} diff --git a/pkg/scrypto/pkcs5_test.go b/pkg/scrypto/pkcs5_test.go new file mode 100644 index 0000000..8bd1a1d --- /dev/null +++ b/pkg/scrypto/pkcs5_test.go @@ -0,0 +1,52 @@ +package scrypto + +import ( + "slices" + "testing" +) + +func TestPadPKCS5(t *testing.T) { + tests := []struct { + data []byte + want []byte + }{ + {data: []byte(""), want: []byte("\x04\x04\x04\x04")}, + {data: []byte("a"), want: []byte("a\x03\x03\x03")}, + {data: []byte("ab"), want: []byte("ab\x02\x02")}, + {data: []byte("abc"), want: []byte("abc\x01")}, + {data: []byte("abcd"), want: []byte("abcd\x04\x04\x04\x04")}, + {data: []byte("abcde"), want: []byte("abcde\x03\x03\x03")}, + {data: []byte("abcdefgh"), want: []byte("abcdefgh\x04\x04\x04\x04")}, + } + for _, tt := range tests { + got := PadPKCS5(tt.data, 4) + if slices.Compare(got, tt.want) != 0 { + t.Errorf("got %v, want %v", got, tt.want) + } + } +} + +func TestUnpadPKCS5(t *testing.T) { + tests := []struct { + data []byte + want []byte + }{ + {data: []byte("\x04\x04\x04\x04"), want: []byte("")}, + {data: []byte("a\x03\x03\x03"), want: []byte("a")}, + {data: []byte("ab\x02\x02"), want: []byte("ab")}, + {data: []byte("abc\x01"), want: []byte("abc")}, + {data: []byte("abcd\x04\x04\x04\x04"), want: []byte("abcd")}, + {data: []byte("abcde\x03\x03\x03"), want: []byte("abcde")}, + {data: []byte("abcdefgh\x04\x04\x04\x04"), want: []byte("abcdefgh")}, + } + for _, tt := range tests { + got, err := UnpadPKCS5(tt.data, 4) + if err != nil { + t.Errorf("got err %+v while wanted %v", err, tt.want) + } else { + if slices.Compare(got, tt.want) != 0 { + t.Errorf("got %v, want %v", got, tt.want) + } + } + } +} diff --git a/pkg/scrypto/random.go b/pkg/scrypto/random.go new file mode 100644 index 0000000..47d9557 --- /dev/null +++ b/pkg/scrypto/random.go @@ -0,0 +1,16 @@ +package scrypto + +import ( + "crypto/rand" + "fmt" +) + +func RandomBytes(n int) []byte { + data := make([]byte, n) + + if _, err := rand.Read(data); err != nil { + panic(fmt.Sprintf("cannot generate random data: %+v", err)) + } + + return data +} |