summaryrefslogtreecommitdiff
path: root/pkg/scrypto/aes256.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/scrypto/aes256.go')
-rw-r--r--pkg/scrypto/aes256.go138
1 files changed, 138 insertions, 0 deletions
diff --git a/pkg/scrypto/aes256.go b/pkg/scrypto/aes256.go
new file mode 100644
index 0000000..27a39d6
--- /dev/null
+++ b/pkg/scrypto/aes256.go
@@ -0,0 +1,138 @@
+package scrypto
+
+import (
+ "bytes"
+ "crypto/aes"
+ "crypto/cipher"
+ "crypto/rand"
+ "encoding/base64"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+)
+
+type AES256Key [32]byte
+
+const (
+ AES256IVSize int = aes.BlockSize
+)
+
+var Zero = AES256Key{}
+
+func (key *AES256Key) Bytes() []byte {
+ return key[:]
+}
+
+func (key *AES256Key) IsZero() bool {
+ return bytes.Equal(key.Bytes(), Zero.Bytes())
+}
+
+func (key *AES256Key) Hex() string {
+ return hex.EncodeToString(key[:])
+}
+
+func (key *AES256Key) FromHex(s string) error {
+ data, err := hex.DecodeString(s)
+ if err != nil {
+ return err
+ }
+
+ if len(data) != 32 {
+ return fmt.Errorf("invalid key size")
+ }
+
+ copy((*key)[:], data[:32])
+
+ return nil
+}
+
+func (key *AES256Key) FromBase64(s string) error {
+ data, err := base64.StdEncoding.DecodeString(s)
+ if err != nil {
+ return err
+ }
+
+ if len(data) != 32 {
+ return fmt.Errorf("invalid key size")
+ }
+
+ copy((*key)[:], data[:32])
+
+ return nil
+}
+
+func (key *AES256Key) MarshalJSON() ([]byte, error) {
+ s := base64.StdEncoding.EncodeToString(key[:])
+ return json.Marshal(s)
+}
+
+func (pkey *AES256Key) UnmarshalJSON(data []byte) error {
+ var s string
+ if err := json.Unmarshal(data, &s); err != nil {
+ return err
+ }
+
+ key, err := base64.StdEncoding.DecodeString(s)
+ if err != nil {
+ return fmt.Errorf("cannot decode base64 value: %w", err)
+ }
+
+ if len(key) != 32 {
+ return fmt.Errorf("invalid key size (must be 32 bytes long)")
+ }
+
+ copy(pkey[:], key)
+
+ return nil
+}
+
+func (key *AES256Key) EncryptAES256(inputData []byte) ([]byte, error) {
+ blockCipher, err := aes.NewCipher(key[:])
+ if err != nil {
+ return nil, fmt.Errorf("cannot create cipher: %w", err)
+ }
+
+ paddedData := PadPKCS5(inputData, aes.BlockSize)
+
+ outputData := make([]byte, AES256IVSize+len(paddedData))
+
+ iv := outputData[:AES256IVSize]
+ encryptedData := outputData[AES256IVSize:]
+
+ if _, err := rand.Read(iv); err != nil {
+ return nil, fmt.Errorf("cannot generate iv: %w", err)
+ }
+
+ encrypter := cipher.NewCBCEncrypter(blockCipher, iv)
+ encrypter.CryptBlocks(encryptedData, paddedData)
+
+ return outputData, nil
+}
+
+func (key *AES256Key) DecryptAES256(inputData []byte) ([]byte, error) {
+ blockCipher, err := aes.NewCipher(key[:])
+ if err != nil {
+ return nil, fmt.Errorf("cannot create cipher: %w", err)
+ }
+
+ if len(inputData) < AES256IVSize {
+ return nil, fmt.Errorf("truncated data")
+ }
+
+ iv := inputData[:AES256IVSize]
+ paddedData := inputData[AES256IVSize:]
+
+ if len(paddedData)%aes.BlockSize != 0 {
+ return nil, fmt.Errorf("invalid padded data length")
+ }
+
+ decrypter := cipher.NewCBCDecrypter(blockCipher, iv)
+ decrypter.CryptBlocks(paddedData, paddedData)
+
+ outputData, err := UnpadPKCS5(paddedData, aes.BlockSize)
+ if err != nil {
+ return nil, fmt.Errorf("invalid padded data: %w", err)
+ }
+
+ return outputData, nil
+}