feat(tfstated): implement GET and POST methods, states are encrypted in a sqlite3 database
This commit is contained in:
parent
baf5aac08e
commit
4ff490806c
18 changed files with 627 additions and 2 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1 +1,4 @@
|
||||||
tfstate
|
tfstate
|
||||||
|
tfstate.db
|
||||||
|
tfstate.db-shm
|
||||||
|
tfstate.db-wal
|
||||||
|
|
35
cmd/tfstated/get.go
Normal file
35
cmd/tfstated/get.go
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.adyxax.org/adyxax/tfstated/pkg/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handleGet(db *database.DB) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Cache-Control", "no-store, no-cache")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if r.URL.Path == "/" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte("{\"msg\": \"No state path provided, cannot GET /\"}"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, err := db.GetState(r.URL.Path); err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
_, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"%+v\"}", err)))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -12,6 +12,8 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.adyxax.org/adyxax/tfstated/pkg/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -22,6 +24,7 @@ type Config struct {
|
||||||
func run(
|
func run(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *Config,
|
config *Config,
|
||||||
|
db *database.DB,
|
||||||
args []string,
|
args []string,
|
||||||
getenv func(string) string,
|
getenv func(string) string,
|
||||||
stdin io.Reader,
|
stdin io.Reader,
|
||||||
|
@ -30,10 +33,20 @@ func run(
|
||||||
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
dataEncryptionKey := getenv("DATA_ENCRYPTION_KEY")
|
||||||
|
if dataEncryptionKey == "" {
|
||||||
|
return fmt.Errorf("the DATA_ENCRYPTION_KEY environment variable is not set")
|
||||||
|
}
|
||||||
|
if err := db.SetDataEncryptionKey(dataEncryptionKey); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
addRoutes(
|
addRoutes(
|
||||||
mux,
|
mux,
|
||||||
|
db,
|
||||||
)
|
)
|
||||||
|
|
||||||
httpServer := &http.Server{
|
httpServer := &http.Server{
|
||||||
Addr: net.JoinHostPort(config.Host, config.Port),
|
Addr: net.JoinHostPort(config.Host, config.Port),
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
|
@ -79,9 +92,17 @@ func main() {
|
||||||
Port: "8080",
|
Port: "8080",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
db, err := database.NewDB(ctx, "./tfstate.db?_txlock=immediate")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "database init error: %+v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
if err := run(
|
if err := run(
|
||||||
ctx,
|
ctx,
|
||||||
&config,
|
&config,
|
||||||
|
db,
|
||||||
os.Args,
|
os.Args,
|
||||||
os.Getenv,
|
os.Getenv,
|
||||||
os.Stdin,
|
os.Stdin,
|
||||||
|
|
32
cmd/tfstated/post.go
Normal file
32
cmd/tfstated/post.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.adyxax.org/adyxax/tfstated/pkg/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handlePost(db *database.DB) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if r.URL.Path == "/" {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte("{\"msg\": \"No state path provided, cannot POST /\"}"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"failed to read request body: %+v\"}", err)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := db.SetState(r.URL.Path, data); err != nil {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte(fmt.Sprintf("{\"msg\": \"%+v\"}", err)))
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -1,7 +1,16 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
func addRoutes(mux *http.ServeMux) {
|
"git.adyxax.org/adyxax/tfstated/pkg/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func addRoutes(
|
||||||
|
mux *http.ServeMux,
|
||||||
|
db *database.DB,
|
||||||
|
) {
|
||||||
mux.Handle("GET /healthz", handleHealthz())
|
mux.Handle("GET /healthz", handleHealthz())
|
||||||
|
mux.Handle("GET /", handleGet(db))
|
||||||
|
mux.Handle("POST /", handlePost(db))
|
||||||
}
|
}
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -1,3 +1,5 @@
|
||||||
module git.adyxax.org/adyxax/tfstated
|
module git.adyxax.org/adyxax/tfstated
|
||||||
|
|
||||||
go 1.23.1
|
go 1.23.1
|
||||||
|
|
||||||
|
require github.com/mattn/go-sqlite3 v1.14.23
|
||||||
|
|
2
go.sum
Normal file
2
go.sum
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
103
pkg/database/db.go
Normal file
103
pkg/database/db.go
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"git.adyxax.org/adyxax/tfstated/pkg/scrypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func initDB(ctx context.Context, url string) (*sql.DB, error) {
|
||||||
|
db, err := sql.Open("sqlite3", url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = db.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if _, err = db.ExecContext(ctx, "PRAGMA busy_timeout = 5000"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type DB struct {
|
||||||
|
ctx context.Context
|
||||||
|
dataEncryptionKey scrypto.AES256Key
|
||||||
|
readDB *sql.DB
|
||||||
|
writeDB *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDB(ctx context.Context, url string) (*DB, error) {
|
||||||
|
readDB, err := initDB(ctx, url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = readDB.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
readDB.SetMaxOpenConns(max(4, runtime.NumCPU()))
|
||||||
|
|
||||||
|
writeDB, err := initDB(ctx, url)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = writeDB.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
writeDB.SetMaxOpenConns(1)
|
||||||
|
|
||||||
|
db := DB{
|
||||||
|
ctx: ctx,
|
||||||
|
readDB: readDB,
|
||||||
|
writeDB: writeDB,
|
||||||
|
}
|
||||||
|
if _, err = db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err = db.Exec("PRAGMA cache_size = 10000000"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err = db.Exec("PRAGMA journal_mode = WAL"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err = db.Exec("PRAGMA synchronous = NORMAL"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = db.migrate(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &db, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Begin() (*sql.Tx, error) {
|
||||||
|
return db.writeDB.Begin()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
if err := db.readDB.Close(); err != nil {
|
||||||
|
_ = db.writeDB.Close()
|
||||||
|
}
|
||||||
|
return db.writeDB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) Exec(query string, args ...any) (sql.Result, error) {
|
||||||
|
return db.writeDB.ExecContext(db.ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) QueryRow(query string, args ...any) *sql.Row {
|
||||||
|
return db.readDB.QueryRowContext(db.ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) SetDataEncryptionKey(s string) error {
|
||||||
|
return db.dataEncryptionKey.FromBase64(s)
|
||||||
|
}
|
63
pkg/database/migrations.go
Normal file
63
pkg/database/migrations.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"io/fs"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed sql/*.sql
|
||||||
|
var schemaFiles embed.FS
|
||||||
|
|
||||||
|
func (db *DB) migrate() error {
|
||||||
|
statements := make([]string, 0)
|
||||||
|
err := fs.WalkDir(schemaFiles, ".", func(path string, d fs.DirEntry, err error) error {
|
||||||
|
if d.IsDir() || err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
var stmts []byte
|
||||||
|
if stmts, err = schemaFiles.ReadFile(path); err != nil {
|
||||||
|
return err
|
||||||
|
} else {
|
||||||
|
statements = append(statements, string(stmts))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var version int
|
||||||
|
if err = tx.QueryRowContext(db.ctx, `SELECT version FROM schema_version;`).Scan(&version); err != nil {
|
||||||
|
if err.Error() == "no such table: schema_version" {
|
||||||
|
version = 0
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for version < len(statements) {
|
||||||
|
if _, err = tx.ExecContext(db.ctx, statements[version]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
version++
|
||||||
|
}
|
||||||
|
if _, err = tx.ExecContext(db.ctx, `DELETE FROM schema_version; INSERT INTO schema_version (version) VALUES (?);`, version); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = tx.Commit(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
10
pkg/database/sql/000_init.sql
Normal file
10
pkg/database/sql/000_init.sql
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
CREATE TABLE schema_version (
|
||||||
|
version INTEGER NOT NULL
|
||||||
|
) STRICT;
|
||||||
|
|
||||||
|
CREATE TABLE states (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
data BLOB NOT NULL
|
||||||
|
) STRICT;
|
||||||
|
CREATE UNIQUE INDEX states_name on states(name);
|
30
pkg/database/states.go
Normal file
30
pkg/database/states.go
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (db *DB) GetState(name string) ([]byte, error) {
|
||||||
|
var encryptedData []byte
|
||||||
|
err := db.QueryRow(`SELECT data FROM states WHERE name = ?;`, name).Scan(&encryptedData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return db.dataEncryptionKey.DecryptAES256(encryptedData)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (db *DB) SetState(name string, data []byte) error {
|
||||||
|
encryptedData, err := db.dataEncryptionKey.EncryptAES256(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(
|
||||||
|
`INSERT INTO states(name, data) VALUES (:name, :data) ON CONFLICT DO UPDATE SET data = :data WHERE name = :name;`,
|
||||||
|
sql.Named("data", encryptedData),
|
||||||
|
sql.Named("name", name),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
13
pkg/scrypto/LICENSE
Normal file
13
pkg/scrypto/LICENSE
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
Copyright (c) 2023-2024 Nicolas Martyanoff <nicolas@n16f.net>
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and/or distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted, provided that the above
|
||||||
|
copyright notice and this permission notice appear in all copies.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
|
||||||
|
REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
|
||||||
|
AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
|
||||||
|
INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
|
||||||
|
LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
|
||||||
|
OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
|
||||||
|
PERFORMANCE OF THIS SOFTWARE.
|
5
pkg/scrypto/README.md
Normal file
5
pkg/scrypto/README.md
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
Original code imported from https://github.com/galdor/go-service/tree/master/pkg/scrypto
|
||||||
|
|
||||||
|
Alterations:
|
||||||
|
- converted the EncryptAES256 and DecryptAES256 functions to methods
|
||||||
|
- rewrote the tests to get rid of the dependency on testify
|
138
pkg/scrypto/aes256.go
Normal file
138
pkg/scrypto/aes256.go
Normal file
|
@ -0,0 +1,138 @@
|
||||||
|
package scrypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AES256Key [32]byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
AES256IVSize int = aes.BlockSize
|
||||||
|
)
|
||||||
|
|
||||||
|
var Zero = AES256Key{}
|
||||||
|
|
||||||
|
func (key *AES256Key) Bytes() []byte {
|
||||||
|
return key[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) IsZero() bool {
|
||||||
|
return bytes.Equal(key.Bytes(), Zero.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) Hex() string {
|
||||||
|
return hex.EncodeToString(key[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) FromHex(s string) error {
|
||||||
|
data, err := hex.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) != 32 {
|
||||||
|
return fmt.Errorf("invalid key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy((*key)[:], data[:32])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) FromBase64(s string) error {
|
||||||
|
data, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) != 32 {
|
||||||
|
return fmt.Errorf("invalid key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy((*key)[:], data[:32])
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) MarshalJSON() ([]byte, error) {
|
||||||
|
s := base64.StdEncoding.EncodeToString(key[:])
|
||||||
|
return json.Marshal(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pkey *AES256Key) UnmarshalJSON(data []byte) error {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
key, err := base64.StdEncoding.DecodeString(s)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot decode base64 value: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(key) != 32 {
|
||||||
|
return fmt.Errorf("invalid key size (must be 32 bytes long)")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(pkey[:], key)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) EncryptAES256(inputData []byte) ([]byte, error) {
|
||||||
|
blockCipher, err := aes.NewCipher(key[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot create cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
paddedData := PadPKCS5(inputData, aes.BlockSize)
|
||||||
|
|
||||||
|
outputData := make([]byte, AES256IVSize+len(paddedData))
|
||||||
|
|
||||||
|
iv := outputData[:AES256IVSize]
|
||||||
|
encryptedData := outputData[AES256IVSize:]
|
||||||
|
|
||||||
|
if _, err := rand.Read(iv); err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot generate iv: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
encrypter := cipher.NewCBCEncrypter(blockCipher, iv)
|
||||||
|
encrypter.CryptBlocks(encryptedData, paddedData)
|
||||||
|
|
||||||
|
return outputData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (key *AES256Key) DecryptAES256(inputData []byte) ([]byte, error) {
|
||||||
|
blockCipher, err := aes.NewCipher(key[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot create cipher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(inputData) < AES256IVSize {
|
||||||
|
return nil, fmt.Errorf("truncated data")
|
||||||
|
}
|
||||||
|
|
||||||
|
iv := inputData[:AES256IVSize]
|
||||||
|
paddedData := inputData[AES256IVSize:]
|
||||||
|
|
||||||
|
if len(paddedData)%aes.BlockSize != 0 {
|
||||||
|
return nil, fmt.Errorf("invalid padded data length")
|
||||||
|
}
|
||||||
|
|
||||||
|
decrypter := cipher.NewCBCDecrypter(blockCipher, iv)
|
||||||
|
decrypter.CryptBlocks(paddedData, paddedData)
|
||||||
|
|
||||||
|
outputData, err := UnpadPKCS5(paddedData, aes.BlockSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid padded data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputData, nil
|
||||||
|
}
|
60
pkg/scrypto/aes256_test.go
Normal file
60
pkg/scrypto/aes256_test.go
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
package scrypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAES256KeyHex(t *testing.T) {
|
||||||
|
testKeyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
|
||||||
|
testKey, _ := hex.DecodeString(testKeyHex)
|
||||||
|
|
||||||
|
var key AES256Key
|
||||||
|
if err := key.FromHex(testKeyHex); err != nil {
|
||||||
|
t.Errorf("got unexpected error %+v", err)
|
||||||
|
}
|
||||||
|
if slices.Compare(testKey, key[:]) != 0 {
|
||||||
|
t.Errorf("got %v, wanted %v", testKey, key[:])
|
||||||
|
}
|
||||||
|
if strings.Compare(testKeyHex, key.Hex()) != 0 {
|
||||||
|
t.Errorf("got %v, wanted %v", testKeyHex, key.Hex())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAES256(t *testing.T) {
|
||||||
|
keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
|
||||||
|
var key AES256Key
|
||||||
|
if err := key.FromHex(keyHex); err != nil {
|
||||||
|
t.Errorf("got unexpected error %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data := []byte("Hello world!")
|
||||||
|
encryptedData, err := key.EncryptAES256(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("got unexpected error when encrypting data %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decryptedData, err := key.DecryptAES256(encryptedData)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("got unexpected error when decrypting data %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Compare(data, decryptedData) != 0 {
|
||||||
|
t.Errorf("got %v, wanted %v", decryptedData, data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAES256InvalidData(t *testing.T) {
|
||||||
|
keyHex := "28278b7c0a25f01d3cab639633b9487f9ea1e9a2176dc9595a3f01323aa44284"
|
||||||
|
var key AES256Key
|
||||||
|
if err := key.FromHex(keyHex); err != nil {
|
||||||
|
t.Errorf("got unexpected error when converting data from base64: %+v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
iv := make([]byte, AES256IVSize)
|
||||||
|
if _, err := key.DecryptAES256(append(iv, []byte("foo")...)); err == nil {
|
||||||
|
t.Error("decrypting operation should have failed")
|
||||||
|
}
|
||||||
|
}
|
31
pkg/scrypto/pkcs5.go
Normal file
31
pkg/scrypto/pkcs5.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package scrypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func PadPKCS5(data []byte, blockSize int) []byte {
|
||||||
|
paddingSize := blockSize - len(data)%blockSize
|
||||||
|
padding := bytes.Repeat([]byte{byte(paddingSize)}, paddingSize)
|
||||||
|
|
||||||
|
return append(data, padding...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func UnpadPKCS5(data []byte, blockSize int) ([]byte, error) {
|
||||||
|
dataSize := len(data)
|
||||||
|
if dataSize%blockSize != 0 {
|
||||||
|
return nil, fmt.Errorf("truncated data")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
return data, nil
|
||||||
|
} else {
|
||||||
|
paddingSize := int(data[dataSize-1])
|
||||||
|
if paddingSize > dataSize || paddingSize > blockSize {
|
||||||
|
return nil, fmt.Errorf("invalid padding size %d", paddingSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data[:dataSize-paddingSize], nil
|
||||||
|
}
|
||||||
|
}
|
52
pkg/scrypto/pkcs5_test.go
Normal file
52
pkg/scrypto/pkcs5_test.go
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
package scrypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPadPKCS5(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
data []byte
|
||||||
|
want []byte
|
||||||
|
}{
|
||||||
|
{data: []byte(""), want: []byte("\x04\x04\x04\x04")},
|
||||||
|
{data: []byte("a"), want: []byte("a\x03\x03\x03")},
|
||||||
|
{data: []byte("ab"), want: []byte("ab\x02\x02")},
|
||||||
|
{data: []byte("abc"), want: []byte("abc\x01")},
|
||||||
|
{data: []byte("abcd"), want: []byte("abcd\x04\x04\x04\x04")},
|
||||||
|
{data: []byte("abcde"), want: []byte("abcde\x03\x03\x03")},
|
||||||
|
{data: []byte("abcdefgh"), want: []byte("abcdefgh\x04\x04\x04\x04")},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got := PadPKCS5(tt.data, 4)
|
||||||
|
if slices.Compare(got, tt.want) != 0 {
|
||||||
|
t.Errorf("got %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnpadPKCS5(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
data []byte
|
||||||
|
want []byte
|
||||||
|
}{
|
||||||
|
{data: []byte("\x04\x04\x04\x04"), want: []byte("")},
|
||||||
|
{data: []byte("a\x03\x03\x03"), want: []byte("a")},
|
||||||
|
{data: []byte("ab\x02\x02"), want: []byte("ab")},
|
||||||
|
{data: []byte("abc\x01"), want: []byte("abc")},
|
||||||
|
{data: []byte("abcd\x04\x04\x04\x04"), want: []byte("abcd")},
|
||||||
|
{data: []byte("abcde\x03\x03\x03"), want: []byte("abcde")},
|
||||||
|
{data: []byte("abcdefgh\x04\x04\x04\x04"), want: []byte("abcdefgh")},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
got, err := UnpadPKCS5(tt.data, 4)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("got err %+v while wanted %v", err, tt.want)
|
||||||
|
} else {
|
||||||
|
if slices.Compare(got, tt.want) != 0 {
|
||||||
|
t.Errorf("got %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
16
pkg/scrypto/random.go
Normal file
16
pkg/scrypto/random.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package scrypto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func RandomBytes(n int) []byte {
|
||||||
|
data := make([]byte, n)
|
||||||
|
|
||||||
|
if _, err := rand.Read(data); err != nil {
|
||||||
|
panic(fmt.Sprintf("cannot generate random data: %+v", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue