feat(tfstated): implement GET and POST methods, states are encrypted in a sqlite3 database

This commit is contained in:
Julien Dessaux 2024-09-30 00:58:49 +02:00
parent baf5aac08e
commit 4ff490806c
Signed by: adyxax
GPG key ID: F92E51B86E07177E
18 changed files with 627 additions and 2 deletions

3
.gitignore vendored
View file

@ -1 +1,4 @@
tfstate tfstate
tfstate.db
tfstate.db-shm
tfstate.db-wal

35
cmd/tfstated/get.go Normal file
View file

@ -0,0 +1,35 @@
package main
import (
"database/sql"
"errors"
"fmt"
"net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database"
)
func handleGet(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Cache-Control", "no-store, no-cache")
w.Header().Set("Content-Type", "application/json")
if r.URL.Path == "/" {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("{\"msg\": \"No state path provided, cannot GET /\"}"))
return
}
if data, err := db.GetState(r.URL.Path); err != nil {
if errors.Is(err, sql.ErrNoRows) {
w.WriteHeader(http.StatusNotFound)
} else {
w.WriteHeader(http.StatusInternalServerError)
}
_, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"%+v\"}", err)))
} else {
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
}
})
}

View file

@ -12,6 +12,8 @@ import (
"os/signal" "os/signal"
"sync" "sync"
"time" "time"
"git.adyxax.org/adyxax/tfstated/pkg/database"
) )
type Config struct { type Config struct {
@ -22,6 +24,7 @@ type Config struct {
func run( func run(
ctx context.Context, ctx context.Context,
config *Config, config *Config,
db *database.DB,
args []string, args []string,
getenv func(string) string, getenv func(string) string,
stdin io.Reader, stdin io.Reader,
@ -30,10 +33,20 @@ func run(
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt) ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel() defer cancel()
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
if dataEncryptionKey == "" {
return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
}
if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil {
return err
}
mux := http.NewServeMux() mux := http.NewServeMux()
addRoutes( addRoutes(
mux, mux,
db,
) )
httpServer := &http.Server{ httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port), Addr: net.JoinHostPort(config.Host, config.Port),
Handler: mux, Handler: mux,
@ -79,9 +92,17 @@ func main() {
Port: "8080", Port: "8080",
} }
db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate")
if err != nil {
fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
os.Exit(1)
}
defer db.Close()
if err := run( if err := run(
ctx, ctx,
&config, &config,
db,
os.Args, os.Args,
os.Getenv, os.Getenv,
os.Stdin, os.Stdin,

32
cmd/tfstated/post.go Normal file
View file

@ -0,0 +1,32 @@
package main
import (
"fmt"
"io"
"net/http"
"git.adyxax.org/adyxax/tfstated/pkg/database"
)
func handlePost(db *database.DB) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if r.URL.Path == "/" {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("{\"msg\": \"No state path provided, cannot POST /\"}"))
return
}
data, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"failed to read request body: %+v\"}", err)))
return
}
if err := db.SetState(r.URL.Path, data); err != nil {
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"%+v\"}", err)))
} else {
w.WriteHeader(http.StatusOK)
}
})
}

View file

@ -1,7 +1,16 @@
package main package main
import "net/http" import (
"net/http"
func addRoutes(mux *http.ServeMux) { "git.adyxax.org/adyxax/tfstated/pkg/database"
)
func addRoutes(
mux *http.ServeMux,
db *database.DB,
) {
mux.Handle("GET /healthz", handleHealthz()) mux.Handle("GET /healthz", handleHealthz())
mux.Handle("GET /", handleGet(db))
mux.Handle("POST /", handlePost(db))
} }

2
go.mod
View file

@ -1,3 +1,5 @@
module git.adyxax.org/adyxax/tfstated module git.adyxax.org/adyxax/tfstated
go 1.23.1 go 1.23.1
require github.com/mattn/go-sqlite3 v1.14.23

2
go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=

103
pkg/database/db.go Normal file
View file

@ -0,0 +1,103 @@
package database
import (
"context"
"database/sql"
"runtime"
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
)
func initDB(ctx context.Context, url string) (*sql.DB, error) {
db, err := sql.Open("sqlite3", url)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = db.Close()
}
}()
if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil {
return nil, err
}
return db, nil
}
type DB struct {
ctx context.Context
dataEncryptionKey scrypto.AES256Key
readDB *sql.DB
writeDB *sql.DB
}
func NewDB(ctx context.Context, url string) (*DB, error) {
readDB, err := initDB(ctx, url)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = readDB.Close()
}
}()
readDB.SetMaxOpenConns(max(4, runtime.NumCPU()))
writeDB, err := initDB(ctx, url)
if err != nil {
return nil, err
}
defer func() {
if err != nil {
_ = writeDB.Close()
}
}()
writeDB.SetMaxOpenConns(1)
db := DB{
ctx: ctx,
readDB: readDB,
writeDB: writeDB,
}
if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil {
return nil, err
}
if _, err = db.Exec("PRAGMA cache_size = 10000000"); err != nil {
return nil, err
}
if _, err = db.Exec("PRAGMA journal_mode = WAL"); err != nil {
return nil, err
}
if _, err = db.Exec("PRAGMA synchronous = NORMAL"); err != nil {
return nil, err
}
if err = db.migrate(); err != nil {
return nil, err
}
return &db, nil
}
func (db *DB) Begin() (*sql.Tx, error) {
return db.writeDB.Begin()
}
func (db *DB) Close() error {
if err := db.readDB.Close(); err != nil {
_ = db.writeDB.Close()
}
return db.writeDB.Close()
}
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
return db.writeDB.ExecContext(db.ctx, query, args...)
}
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
return db.readDB.QueryRowContext(db.ctx, query, args...)
}
func (db *DB) SetDataEncryptionKey(s string) error {
return db.dataEncryptionKey.FromBase64(s)
}

View file

@ -0,0 +1,63 @@
package database
import (
"embed"
"io/fs"
_ "github.com/mattn/go-sqlite3"
)
//go:embed sql/*.sql
var schemaFiles embed.FS
func (db *DB) migrate() error {
statements := make([]string, 0)
err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error {
if d.IsDir() || err != nil {
return err
}
var stmts []byte
if stmts, err = schemaFiles.ReadFile(path); err != nil {
return err
} else {
statements = append(statements, string(stmts))
}
return nil
})
if err != nil {
return err
}
tx, err := db.Begin()
if err != nil {
return err
}
defer func() {
if err != nil {
_ = tx.Rollback()
}
}()
var version int
if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
if err.Error() == "no such table: schema_version" {
version = 0
} else {
return err
}
}
for version < len(statements) {
if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
return err
}
version++
}
if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
return nil
}

View file

@ -0,0 +1,10 @@
CREATE TABLE schema_version (
version INTEGER NOT NULL
) STRICT;
CREATE TABLE states (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
data BLOB NOT NULL
) STRICT;
CREATE UNIQUE INDEX states_name on states(name);

30
pkg/database/states.go Normal file
View file

@ -0,0 +1,30 @@
package database
import (
"database/sql"
)
func (db *DB) GetState(name string) ([]byte, error) {
var encryptedData []byte
err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData)
if err != nil {
return nil, err
}
return db.dataEncryptionKey.DecryptAES256(encryptedData)
}
func (db *DB) SetState(name string, data []byte) error {
encryptedData, err := db.dataEncryptionKey.EncryptAES256(data)
if err != nil {
return err
}
_, err = db.Exec(
`INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`,
sql.Named("data", encryptedData),
sql.Named("name", name),
)
if err != nil {
return err
}
return nil
}

13
pkg/scrypto/LICENSE Normal file
View file

@ -0,0 +1,13 @@
Copyright (c) 2023-2024 Nicolas Martyanoff <nicolas@n16f.net>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
PERFORMANCE OF THIS SOFTWARE.

5
pkg/scrypto/README.md Normal file
View file

@ -0,0 +1,5 @@
Original code imported from https://github.com/galdor/go-service/tree/master/pkg/scrypto
Alterations:
- converted the EncryptAES256 and DecryptAES256 functions to methods
- rewrote the tests to get rid of the dependency on testify

138
pkg/scrypto/aes256.go Normal file
View file

@ -0,0 +1,138 @@
package scrypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
)
type AES256Key [32]byte
const (
AES256IVSize int = aes.BlockSize
)
var Zero = AES256Key{}
func (key *AES256Key) Bytes() []byte {
return key[:]
}
func (key *AES256Key) IsZero() bool {
return bytes.Equal(key.Bytes(), Zero.Bytes())
}
func (key *AES256Key) Hex() string {
return hex.EncodeToString(key[:])
}
func (key *AES256Key) FromHex(s string) error {
data, err := hex.DecodeString(s)
if err != nil {
return err
}
if len(data) != 32 {
return fmt.Errorf("invalid key size")
}
copy((*key)[:], data[:32])
return nil
}
func (key *AES256Key) FromBase64(s string) error {
data, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return err
}
if len(data) != 32 {
return fmt.Errorf("invalid key size")
}
copy((*key)[:], data[:32])
return nil
}
func (key *AES256Key) MarshalJSON() ([]byte, error) {
s := base64.StdEncoding.EncodeToString(key[:])
return json.Marshal(s)
}
func (pkey *AES256Key) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return err
}
key, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return fmt.Errorf("cannot decode base64 value: %w", err)
}
if len(key) != 32 {
return fmt.Errorf("invalid key size (must be 32 bytes long)")
}
copy(pkey[:], key)
return nil
}
func (key *AES256Key) EncryptAES256(inputData []byte) ([]byte, error) {
blockCipher, err := aes.NewCipher(key[:])
if err != nil {
return nil, fmt.Errorf("cannot create cipher: %w", err)
}
paddedData := PadPKCS5(inputData, aes.BlockSize)
outputData := make([]byte, AES256IVSize+len(paddedData))
iv := outputData[:AES256IVSize]
encryptedData := outputData[AES256IVSize:]
if _, err := rand.Read(iv); err != nil {
return nil, fmt.Errorf("cannot generate iv: %w", err)
}
encrypter := cipher.NewCBCEncrypter(blockCipher, iv)
encrypter.CryptBlocks(encryptedData, paddedData)
return outputData, nil
}
func (key *AES256Key) DecryptAES256(inputData []byte) ([]byte, error) {
blockCipher, err := aes.NewCipher(key[:])
if err != nil {
return nil, fmt.Errorf("cannot create cipher: %w", err)
}
if len(inputData) < AES256IVSize {
return nil, fmt.Errorf("truncated data")
}
iv := inputData[:AES256IVSize]
paddedData := inputData[AES256IVSize:]
if len(paddedData)%aes.BlockSize != 0 {
return nil, fmt.Errorf("invalid padded data length")
}
decrypter := cipher.NewCBCDecrypter(blockCipher, iv)
decrypter.CryptBlocks(paddedData, paddedData)
outputData, err := UnpadPKCS5(paddedData, aes.BlockSize)
if err != nil {
return nil, fmt.Errorf("invalid padded data: %w", err)
}
return outputData, nil
}

View file

@ -0,0 +1,60 @@
package scrypto
import (
"encoding/hex"
"slices"
"strings"
"testing"
)
func TestAES256KeyHex(t *testing.T) {
testKeyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
testKey, _ := hex.DecodeString(testKeyHex)
var key AES256Key
if err := key.FromHex(testKeyHex); err != nil {
t.Errorf("got unexpected error %+v", err)
}
if slices.Compare(testKey, key[:]) != 0 {
t.Errorf("got %v, wanted %v", testKey, key[:])
}
if strings.Compare(testKeyHex, key.Hex()) != 0 {
t.Errorf("got %v, wanted %v", testKeyHex, key.Hex())
}
}
func TestAES256(t *testing.T) {
keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
var key AES256Key
if err := key.FromHex(keyHex); err != nil {
t.Errorf("got unexpected error %+v", err)
}
data := []byte("Hello world!")
encryptedData, err := key.EncryptAES256(data)
if err != nil {
t.Errorf("got unexpected error when encrypting data %+v", err)
}
decryptedData, err := key.DecryptAES256(encryptedData)
if err != nil {
t.Errorf("got unexpected error when decrypting data %+v", err)
}
if slices.Compare(data, decryptedData) != 0 {
t.Errorf("got %v, wanted %v", decryptedData, data)
}
}
func TestAES256InvalidData(t *testing.T) {
keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
var key AES256Key
if err := key.FromHex(keyHex); err != nil {
t.Errorf("got unexpected error when converting data from base64: %+v", err)
}
iv := make([]byte, AES256IVSize)
if _, err := key.DecryptAES256(append(iv, []byte("foo")...)); err == nil {
t.Error("decrypting operation should have failed")
}
}

31
pkg/scrypto/pkcs5.go Normal file
View file

@ -0,0 +1,31 @@
package scrypto
import (
"bytes"
"fmt"
)
func PadPKCS5(data []byte, blockSize int) []byte {
paddingSize := blockSize - len(data)%blockSize
padding := bytes.Repeat([]byte{byte(paddingSize)}, paddingSize)
return append(data, padding...)
}
func UnpadPKCS5(data []byte, blockSize int) ([]byte, error) {
dataSize := len(data)
if dataSize%blockSize != 0 {
return nil, fmt.Errorf("truncated data")
}
if len(data) == 0 {
return data, nil
} else {
paddingSize := int(data[dataSize-1])
if paddingSize > dataSize || paddingSize > blockSize {
return nil, fmt.Errorf("invalid padding size %d", paddingSize)
}
return data[:dataSize-paddingSize], nil
}
}

52
pkg/scrypto/pkcs5_test.go Normal file
View file

@ -0,0 +1,52 @@
package scrypto
import (
"slices"
"testing"
)
func TestPadPKCS5(t *testing.T) {
tests := []struct {
data []byte
want []byte
}{
{data: []byte(""), want: []byte("\x04\x04\x04\x04")},
{data: []byte("a"), want: []byte("a\x03\x03\x03")},
{data: []byte("ab"), want: []byte("ab\x02\x02")},
{data: []byte("abc"), want: []byte("abc\x01")},
{data: []byte("abcd"), want: []byte("abcd\x04\x04\x04\x04")},
{data: []byte("abcde"), want: []byte("abcde\x03\x03\x03")},
{data: []byte("abcdefgh"), want: []byte("abcdefgh\x04\x04\x04\x04")},
}
for _, tt := range tests {
got := PadPKCS5(tt.data, 4)
if slices.Compare(got, tt.want) != 0 {
t.Errorf("got %v, want %v", got, tt.want)
}
}
}
func TestUnpadPKCS5(t *testing.T) {
tests := []struct {
data []byte
want []byte
}{
{data: []byte("\x04\x04\x04\x04"), want: []byte("")},
{data: []byte("a\x03\x03\x03"), want: []byte("a")},
{data: []byte("ab\x02\x02"), want: []byte("ab")},
{data: []byte("abc\x01"), want: []byte("abc")},
{data: []byte("abcd\x04\x04\x04\x04"), want: []byte("abcd")},
{data: []byte("abcde\x03\x03\x03"), want: []byte("abcde")},
{data: []byte("abcdefgh\x04\x04\x04\x04"), want: []byte("abcdefgh")},
}
for _, tt := range tests {
got, err := UnpadPKCS5(tt.data, 4)
if err != nil {
t.Errorf("got err %+v while wanted %v", err, tt.want)
} else {
if slices.Compare(got, tt.want) != 0 {
t.Errorf("got %v, want %v", got, tt.want)
}
}
}
}

16
pkg/scrypto/random.go Normal file
View file

@ -0,0 +1,16 @@
package scrypto
import (
"crypto/rand"
"fmt"
)
func RandomBytes(n int) []byte {
data := make([]byte, n)
if _, err := rand.Read(data); err != nil {
panic(fmt.Sprintf("cannot generate random data: %+v", err))
}
return data
}