summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulien Dessaux2024-09-30 00:58:49 +0200
committerJulien Dessaux2024-09-30 01:00:59 +0200
commit4ff490806c826cf2da4c2291ed924f0a49383fce (patch)
tree6870f4883cd03a824095b969500f08fb59f04038
parentchore(tfstated): rename tfstate to tfstated (diff)
downloadtfstated-4ff490806c826cf2da4c2291ed924f0a49383fce.tar.gz
tfstated-4ff490806c826cf2da4c2291ed924f0a49383fce.tar.bz2
tfstated-4ff490806c826cf2da4c2291ed924f0a49383fce.zip
feat(tfstated): implement GET and POST methods, states are encrypted in a sqlite3 database
-rw-r--r--.gitignore3
-rw-r--r--cmd/tfstated/get.go35
-rw-r--r--cmd/tfstated/main.go21
-rw-r--r--cmd/tfstated/post.go32
-rw-r--r--cmd/tfstated/routes.go13
-rw-r--r--go.mod2
-rw-r--r--go.sum2
-rw-r--r--pkg/database/db.go103
-rw-r--r--pkg/database/migrations.go63
-rw-r--r--pkg/database/sql/000_init.sql10
-rw-r--r--pkg/database/states.go30
-rw-r--r--pkg/scrypto/LICENSE13
-rw-r--r--pkg/scrypto/README.md5
-rw-r--r--pkg/scrypto/aes256.go138
-rw-r--r--pkg/scrypto/aes256_test.go60
-rw-r--r--pkg/scrypto/pkcs5.go31
-rw-r--r--pkg/scrypto/pkcs5_test.go52
-rw-r--r--pkg/scrypto/random.go16
18 files changed, 627 insertions, 2 deletions
diff --git a/.gitignore b/.gitignore
index 1788154..8bdbc77 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,4 @@
tfstate
+tfstate.db
+tfstate.db-shm
+tfstate.db-wal
diff --git a/cmd/tfstated/get.go b/cmd/tfstated/get.go
new file mode 100644
index 0000000..6de78a5
--- /dev/null
+++ b/cmd/tfstated/get.go
@@ -0,0 +1,35 @@
+package main
+
+import (
+ "database/sql"
+ "errors"
+ "fmt"
+ "net/http"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/database"
+)
+
+func handleGet(db *database.DB) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Cache-Control", "no-store, no-cache")
+ w.Header().Set("Content-Type", "application/json")
+
+ if r.URL.Path == "/" {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte("{\"msg\": \"No state path provided, cannot GET /\"}"))
+ return
+ }
+
+ if data, err := db.GetState(r.URL.Path); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ w.WriteHeader(http.StatusNotFound)
+ } else {
+ w.WriteHeader(http.StatusInternalServerError)
+ }
+ _, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"%+v\"}", err)))
+ } else {
+ w.WriteHeader(http.StatusOK)
+ _, _ = w.Write(data)
+ }
+ })
+}
diff --git a/cmd/tfstated/main.go b/cmd/tfstated/main.go
index e0b33a5..61248bc 100644
--- a/cmd/tfstated/main.go
+++ b/cmd/tfstated/main.go
@@ -12,6 +12,8 @@ import (
"os/signal"
"sync"
"time"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/database"
)
type Config struct {
@@ -22,6 +24,7 @@ type Config struct {
func run(
ctx context.Context,
config *Config,
+ db *database.DB,
args []string,
getenv func(string) string,
stdin io.Reader,
@@ -30,10 +33,20 @@ func run(
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()
+ dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
+ if dataEncryptionKey == "" {
+ return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
+ }
+ if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil {
+ return err
+ }
+
mux := http.NewServeMux()
addRoutes(
mux,
+ db,
)
+
httpServer := &http.Server{
Addr: net.JoinHostPort(config.Host, config.Port),
Handler: mux,
@@ -79,9 +92,17 @@ func main() {
Port: "8080",
}
+ db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate")
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
+ os.Exit(1)
+ }
+ defer db.Close()
+
if err := run(
ctx,
&config,
+ db,
os.Args,
os.Getenv,
os.Stdin,
diff --git a/cmd/tfstated/post.go b/cmd/tfstated/post.go
new file mode 100644
index 0000000..e47f5ad
--- /dev/null
+++ b/cmd/tfstated/post.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+ "fmt"
+ "io"
+ "net/http"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/database"
+)
+
+func handlePost(db *database.DB) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ if r.URL.Path == "/" {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte("{\"msg\": \"No state path provided, cannot POST /\"}"))
+ return
+ }
+ data, err := io.ReadAll(r.Body)
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest)
+ _, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"failed to read request body: %+v\"}", err)))
+ return
+ }
+ if err := db.SetState(r.URL.Path, data); err != nil {
+ w.WriteHeader(http.StatusInternalServerError)
+ _, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"%+v\"}", err)))
+ } else {
+ w.WriteHeader(http.StatusOK)
+ }
+ })
+}
diff --git a/cmd/tfstated/routes.go b/cmd/tfstated/routes.go
index f583fc7..32ab111 100644
--- a/cmd/tfstated/routes.go
+++ b/cmd/tfstated/routes.go
@@ -1,7 +1,16 @@
package main
-import "net/http"
+import (
+ "net/http"
-func addRoutes(mux *http.ServeMux) {
+ "git.adyxax.org/adyxax/tfstated/pkg/database"
+)
+
+func addRoutes(
+ mux *http.ServeMux,
+ db *database.DB,
+) {
mux.Handle("GET /healthz", handleHealthz())
+ mux.Handle("GET /", handleGet(db))
+ mux.Handle("POST /", handlePost(db))
}
diff --git a/go.mod b/go.mod
index 9a846e2..14aa55d 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,5 @@
module git.adyxax.org/adyxax/tfstated
go 1.23.1
+
+require github.com/mattn/go-sqlite3 v1.14.23
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..32531fa
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,2 @@
+github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
+github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
diff --git a/pkg/database/db.go b/pkg/database/db.go
new file mode 100644
index 0000000..2744570
--- /dev/null
+++ b/pkg/database/db.go
@@ -0,0 +1,103 @@
+package database
+
+import (
+ "context"
+ "database/sql"
+ "runtime"
+
+ "git.adyxax.org/adyxax/tfstated/pkg/scrypto"
+)
+
+func initDB(ctx context.Context, url string) (*sql.DB, error) {
+ db, err := sql.Open("sqlite3", url)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ _ = db.Close()
+ }
+ }()
+ if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil {
+ return nil, err
+ }
+
+ return db, nil
+}
+
+type DB struct {
+ ctx context.Context
+ dataEncryptionKey scrypto.AES256Key
+ readDB *sql.DB
+ writeDB *sql.DB
+}
+
+func NewDB(ctx context.Context, url string) (*DB, error) {
+ readDB, err := initDB(ctx, url)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ _ = readDB.Close()
+ }
+ }()
+ readDB.SetMaxOpenConns(max(4, runtime.NumCPU()))
+
+ writeDB, err := initDB(ctx, url)
+ if err != nil {
+ return nil, err
+ }
+ defer func() {
+ if err != nil {
+ _ = writeDB.Close()
+ }
+ }()
+ writeDB.SetMaxOpenConns(1)
+
+ db := DB{
+ ctx: ctx,
+ readDB: readDB,
+ writeDB: writeDB,
+ }
+ if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil {
+ return nil, err
+ }
+ if _, err = db.Exec("PRAGMA cache_size = 10000000"); err != nil {
+ return nil, err
+ }
+ if _, err = db.Exec("PRAGMA journal_mode = WAL"); err != nil {
+ return nil, err
+ }
+ if _, err = db.Exec("PRAGMA synchronous = NORMAL"); err != nil {
+ return nil, err
+ }
+ if err = db.migrate(); err != nil {
+ return nil, err
+ }
+
+ return &db, nil
+}
+
+func (db *DB) Begin() (*sql.Tx, error) {
+ return db.writeDB.Begin()
+}
+
+func (db *DB) Close() error {
+ if err := db.readDB.Close(); err != nil {
+ _ = db.writeDB.Close()
+ }
+ return db.writeDB.Close()
+}
+
+func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
+ return db.writeDB.ExecContext(db.ctx, query, args...)
+}
+
+func (db *DB) QueryRow(query string, args ...any) *sql.Row {
+ return db.readDB.QueryRowContext(db.ctx, query, args...)
+}
+
+func (db *DB) SetDataEncryptionKey(s string) error {
+ return db.dataEncryptionKey.FromBase64(s)
+}
diff --git a/pkg/database/migrations.go b/pkg/database/migrations.go
new file mode 100644
index 0000000..b460884
--- /dev/null
+++ b/pkg/database/migrations.go
@@ -0,0 +1,63 @@
+package database
+
+import (
+ "embed"
+ "io/fs"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+//go:embed sql/*.sql
+var schemaFiles embed.FS
+
+func (db *DB) migrate() error {
+ statements := make([]string, 0)
+ err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error {
+ if d.IsDir() || err != nil {
+ return err
+ }
+ var stmts []byte
+ if stmts, err = schemaFiles.ReadFile(path); err != nil {
+ return err
+ } else {
+ statements = append(statements, string(stmts))
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+ defer func() {
+ if err != nil {
+ _ = tx.Rollback()
+ }
+ }()
+
+ var version int
+ if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
+ if err.Error() == "no such table: schema_version" {
+ version = 0
+ } else {
+ return err
+ }
+ }
+
+ for version < len(statements) {
+ if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
+ return err
+ }
+ version++
+ }
+ if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
+ return err
+ }
+ if err = tx.Commit(); err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/database/sql/000_init.sql b/pkg/database/sql/000_init.sql
new file mode 100644
index 0000000..0248896
--- /dev/null
+++ b/pkg/database/sql/000_init.sql
@@ -0,0 +1,10 @@
+CREATE TABLE schema_version (
+ version INTEGER NOT NULL
+) STRICT;
+
+CREATE TABLE states (
+ id INTEGER PRIMARY KEY,
+ name TEXT NOT NULL,
+ data BLOB NOT NULL
+) STRICT;
+CREATE UNIQUE INDEX states_name on states(name);
diff --git a/pkg/database/states.go b/pkg/database/states.go
new file mode 100644
index 0000000..ef2263b
--- /dev/null
+++ b/pkg/database/states.go
@@ -0,0 +1,30 @@
+package database
+
+import (
+ "database/sql"
+)
+
+func (db *DB) GetState(name string) ([]byte, error) {
+ var encryptedData []byte
+ err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData)
+ if err != nil {
+ return nil, err
+ }
+ return db.dataEncryptionKey.DecryptAES256(encryptedData)
+}
+
+func (db *DB) SetState(name string, data []byte) error {
+ encryptedData, err := db.dataEncryptionKey.EncryptAES256(data)
+ if err != nil {
+ return err
+ }
+ _, err = db.Exec(
+ `INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`,
+ sql.Named("data", encryptedData),
+ sql.Named("name", name),
+ )
+ if err != nil {
+ return err
+ }
+ return nil
+}
diff --git a/pkg/scrypto/LICENSE b/pkg/scrypto/LICENSE
new file mode 100644
index 0000000..aeafaf4
--- /dev/null
+++ b/pkg/scrypto/LICENSE
@@ -0,0 +1,13 @@
+Copyright (c) 2023-2024 Nicolas Martyanoff <nicolas@n16f.net>
+
+Permission to use, copy, modify, and/or distribute this software for any
+purpose with or without fee is hereby granted, provided that the above
+copyright notice and this permission notice appear in all copies.
+
+THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
+REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
+INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+PERFORMANCE OF THIS SOFTWARE.
diff --git a/pkg/scrypto/README.md b/pkg/scrypto/README.md
new file mode 100644
index 0000000..e83bc8e
--- /dev/null
+++ b/pkg/scrypto/README.md
@@ -0,0 +1,5 @@
+Original code imported from https://github.com/galdor/go-service/tree/master/pkg/scrypto
+
+Alterations:
+- converted the EncryptAES256 and DecryptAES256 functions to methods
+- rewrote the tests to get rid of the dependency on testify
diff --git a/pkg/scrypto/aes256.go b/pkg/scrypto/aes256.go
new file mode 100644
index 0000000..27a39d6
--- /dev/null
+++ b/pkg/scrypto/aes256.go
@@ -0,0 +1,138 @@
+package scrypto
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+)
+
+type AES256Key [32]byte
+
+const (
+ AES256IVSize int = aes.BlockSize
+)
+
+var Zero = AES256Key{}
+
+func (key *AES256Key) Bytes() []byte {
+ return key[:]
+}
+
+func (key *AES256Key) IsZero() bool {
+ return bytes.Equal(key.Bytes(), Zero.Bytes())
+}
+
+func (key *AES256Key) Hex() string {
+ return hex.EncodeToString(key[:])
+}
+
+func (key *AES256Key) FromHex(s string) error {
+ data, err := hex.DecodeString(s)
+ if err != nil {
+ return err
+ }
+
+ if len(data) != 32 {
+ return fmt.Errorf("invalid key size")
+ }
+
+ copy((*key)[:], data[:32])
+
+ return nil
+}
+
+func (key *AES256Key) FromBase64(s string) error {
+ data, err := base64.StdEncoding.DecodeString(s)
+ if err != nil {
+ return err
+ }
+
+ if len(data) != 32 {
+ return fmt.Errorf("invalid key size")
+ }
+
+ copy((*key)[:], data[:32])
+
+ return nil
+}
+
+func (key *AES256Key) MarshalJSON() ([]byte, error) {
+ s := base64.StdEncoding.EncodeToString(key[:])
+ return json.Marshal(s)
+}
+
+func (pkey *AES256Key) UnmarshalJSON(data []byte) error {
+ var s string
+ if err := json.Unmarshal(data, &s); err != nil {
+ return err
+ }
+
+ key, err := base64.StdEncoding.DecodeString(s)
+ if err != nil {
+ return fmt.Errorf("cannot decode base64 value: %w", err)
+ }
+
+ if len(key) != 32 {
+ return fmt.Errorf("invalid key size (must be 32 bytes long)")
+ }
+
+ copy(pkey[:], key)
+
+ return nil
+}
+
+func (key *AES256Key) EncryptAES256(inputData []byte) ([]byte, error) {
+ blockCipher, err := aes.NewCipher(key[:])
+ if err != nil {
+ return nil, fmt.Errorf("cannot create cipher: %w", err)
+ }
+
+ paddedData := PadPKCS5(inputData, aes.BlockSize)
+
+ outputData := make([]byte, AES256IVSize+len(paddedData))
+
+ iv := outputData[:AES256IVSize]
+ encryptedData := outputData[AES256IVSize:]
+
+ if _, err := rand.Read(iv); err != nil {
+ return nil, fmt.Errorf("cannot generate iv: %w", err)
+ }
+
+ encrypter := cipher.NewCBCEncrypter(blockCipher, iv)
+ encrypter.CryptBlocks(encryptedData, paddedData)
+
+ return outputData, nil
+}
+
+func (key *AES256Key) DecryptAES256(inputData []byte) ([]byte, error) {
+ blockCipher, err := aes.NewCipher(key[:])
+ if err != nil {
+ return nil, fmt.Errorf("cannot create cipher: %w", err)
+ }
+
+ if len(inputData) < AES256IVSize {
+ return nil, fmt.Errorf("truncated data")
+ }
+
+ iv := inputData[:AES256IVSize]
+ paddedData := inputData[AES256IVSize:]
+
+ if len(paddedData)%aes.BlockSize != 0 {
+ return nil, fmt.Errorf("invalid padded data length")
+ }
+
+ decrypter := cipher.NewCBCDecrypter(blockCipher, iv)
+ decrypter.CryptBlocks(paddedData, paddedData)
+
+ outputData, err := UnpadPKCS5(paddedData, aes.BlockSize)
+ if err != nil {
+ return nil, fmt.Errorf("invalid padded data: %w", err)
+ }
+
+ return outputData, nil
+}
diff --git a/pkg/scrypto/aes256_test.go b/pkg/scrypto/aes256_test.go
new file mode 100644
index 0000000..9914625
--- /dev/null
+++ b/pkg/scrypto/aes256_test.go
@@ -0,0 +1,60 @@
+package scrypto
+
+import (
+ "encoding/hex"
+ "slices"
+ "strings"
+ "testing"
+)
+
+func TestAES256KeyHex(t *testing.T) {
+ testKeyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
+ testKey, _ := hex.DecodeString(testKeyHex)
+
+ var key AES256Key
+ if err := key.FromHex(testKeyHex); err != nil {
+ t.Errorf("got unexpected error %+v", err)
+ }
+ if slices.Compare(testKey, key[:]) != 0 {
+ t.Errorf("got %v, wanted %v", testKey, key[:])
+ }
+ if strings.Compare(testKeyHex, key.Hex()) != 0 {
+ t.Errorf("got %v, wanted %v", testKeyHex, key.Hex())
+ }
+}
+
+func TestAES256(t *testing.T) {
+ keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
+ var key AES256Key
+ if err := key.FromHex(keyHex); err != nil {
+ t.Errorf("got unexpected error %+v", err)
+ }
+
+ data := []byte("Hello world!")
+ encryptedData, err := key.EncryptAES256(data)
+ if err != nil {
+ t.Errorf("got unexpected error when encrypting data %+v", err)
+ }
+
+ decryptedData, err := key.DecryptAES256(encryptedData)
+ if err != nil {
+ t.Errorf("got unexpected error when decrypting data %+v", err)
+ }
+
+ if slices.Compare(data, decryptedData) != 0 {
+ t.Errorf("got %v, wanted %v", decryptedData, data)
+ }
+}
+
+func TestAES256InvalidData(t *testing.T) {
+ keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
+ var key AES256Key
+ if err := key.FromHex(keyHex); err != nil {
+ t.Errorf("got unexpected error when converting data from base64: %+v", err)
+ }
+
+ iv := make([]byte, AES256IVSize)
+ if _, err := key.DecryptAES256(append(iv, []byte("foo")...)); err == nil {
+ t.Error("decrypting operation should have failed")
+ }
+}
diff --git a/pkg/scrypto/pkcs5.go b/pkg/scrypto/pkcs5.go
new file mode 100644
index 0000000..b41c564
--- /dev/null
+++ b/pkg/scrypto/pkcs5.go
@@ -0,0 +1,31 @@
+package scrypto
+
+import (
+ "bytes"
+ "fmt"
+)
+
+func PadPKCS5(data []byte, blockSize int) []byte {
+ paddingSize := blockSize - len(data)%blockSize
+ padding := bytes.Repeat([]byte{byte(paddingSize)}, paddingSize)
+
+ return append(data, padding...)
+}
+
+func UnpadPKCS5(data []byte, blockSize int) ([]byte, error) {
+ dataSize := len(data)
+ if dataSize%blockSize != 0 {
+ return nil, fmt.Errorf("truncated data")
+ }
+
+ if len(data) == 0 {
+ return data, nil
+ } else {
+ paddingSize := int(data[dataSize-1])
+ if paddingSize > dataSize || paddingSize > blockSize {
+ return nil, fmt.Errorf("invalid padding size %d", paddingSize)
+ }
+
+ return data[:dataSize-paddingSize], nil
+ }
+}
diff --git a/pkg/scrypto/pkcs5_test.go b/pkg/scrypto/pkcs5_test.go
new file mode 100644
index 0000000..8bd1a1d
--- /dev/null
+++ b/pkg/scrypto/pkcs5_test.go
@@ -0,0 +1,52 @@
+package scrypto
+
+import (
+ "slices"
+ "testing"
+)
+
+func TestPadPKCS5(t *testing.T) {
+ tests := []struct {
+ data []byte
+ want []byte
+ }{
+ {data: []byte(""), want: []byte("\x04\x04\x04\x04")},
+ {data: []byte("a"), want: []byte("a\x03\x03\x03")},
+ {data: []byte("ab"), want: []byte("ab\x02\x02")},
+ {data: []byte("abc"), want: []byte("abc\x01")},
+ {data: []byte("abcd"), want: []byte("abcd\x04\x04\x04\x04")},
+ {data: []byte("abcde"), want: []byte("abcde\x03\x03\x03")},
+ {data: []byte("abcdefgh"), want: []byte("abcdefgh\x04\x04\x04\x04")},
+ }
+ for _, tt := range tests {
+ got := PadPKCS5(tt.data, 4)
+ if slices.Compare(got, tt.want) != 0 {
+ t.Errorf("got %v, want %v", got, tt.want)
+ }
+ }
+}
+
+func TestUnpadPKCS5(t *testing.T) {
+ tests := []struct {
+ data []byte
+ want []byte
+ }{
+ {data: []byte("\x04\x04\x04\x04"), want: []byte("")},
+ {data: []byte("a\x03\x03\x03"), want: []byte("a")},
+ {data: []byte("ab\x02\x02"), want: []byte("ab")},
+ {data: []byte("abc\x01"), want: []byte("abc")},
+ {data: []byte("abcd\x04\x04\x04\x04"), want: []byte("abcd")},
+ {data: []byte("abcde\x03\x03\x03"), want: []byte("abcde")},
+ {data: []byte("abcdefgh\x04\x04\x04\x04"), want: []byte("abcdefgh")},
+ }
+ for _, tt := range tests {
+ got, err := UnpadPKCS5(tt.data, 4)
+ if err != nil {
+ t.Errorf("got err %+v while wanted %v", err, tt.want)
+ } else {
+ if slices.Compare(got, tt.want) != 0 {
+ t.Errorf("got %v, want %v", got, tt.want)
+ }
+ }
+ }
+}
diff --git a/pkg/scrypto/random.go b/pkg/scrypto/random.go
new file mode 100644
index 0000000..47d9557
--- /dev/null
+++ b/pkg/scrypto/random.go
@@ -0,0 +1,16 @@
+package scrypto
+
+import (
+ "crypto/rand"
+ "fmt"
+)
+
+func RandomBytes(n int) []byte {
+ data := make([]byte, n)
+
+ if _, err := rand.Read(data); err != nil {
+ panic(fmt.Sprintf("cannot generate random data: %+v", err))
+ }
+
+ return data
+}