tfstated/pkg/scrypto/aes256.go

138 lines
2.7 KiB
Go

package scrypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
)
// AES256Key is a 256-bit AES key held as a fixed-size 32-byte array.
type AES256Key [32]byte
// AES256IVSize is the size in bytes of the initialization vector that
// prefixes encrypted data. It is equal to the AES block size.
const AES256IVSize int = aes.BlockSize
// Zero is the all-zero key. IsZero compares keys against it.
var Zero AES256Key
// Bytes returns the key as a 32-byte slice. The slice shares storage with
// the key, so writes through it modify the key itself.
func (key *AES256Key) Bytes() []byte {
	return (*key)[:]
}
// IsZero reports whether the key holds the same bytes as the package-level
// Zero key.
func (key *AES256Key) IsZero() bool {
	return bytes.Equal((*key)[:], Zero[:])
}
// Hex returns the key encoded as a 64-character lowercase hexadecimal string.
func (key *AES256Key) Hex() string {
	encoded := make([]byte, hex.EncodedLen(len(*key)))
	hex.Encode(encoded, (*key)[:])
	return string(encoded)
}
// FromHex replaces the key with the value decoded from the hexadecimal
// string s. It returns the decoder's error when s is not valid hex, and an
// error when s does not decode to exactly 32 bytes. The key is left unchanged
// on failure.
func (key *AES256Key) FromHex(s string) error {
	decoded, err := hex.DecodeString(s)
	switch {
	case err != nil:
		return err
	case len(decoded) != len(*key):
		return fmt.Errorf("invalid key size")
	}
	copy(key[:], decoded)
	return nil
}
// FromBase64 replaces the key with the value decoded from s, which must be
// standard (padded) base64. It returns the decoder's error when s is not
// valid base64, and an error when s does not decode to exactly 32 bytes. The
// key is left unchanged on failure.
func (key *AES256Key) FromBase64(s string) error {
	raw, decodeErr := base64.StdEncoding.DecodeString(s)
	if decodeErr != nil {
		return decodeErr
	}
	if size := len(raw); size != len(*key) {
		return fmt.Errorf("invalid key size")
	}
	copy(key[:], raw)
	return nil
}
// MarshalJSON encodes the key as a JSON string containing its standard
// (padded) base64 representation, the format UnmarshalJSON reads back.
//
// The receiver is a value on purpose: it makes both AES256Key and
// *AES256Key implement json.Marshaler. With a pointer receiver,
// encoding/json cannot call the method on non-addressable values (map
// values, or fields of a struct passed to json.Marshal by value) and instead
// emits the key as an array of 32 numbers, which UnmarshalJSON rejects.
func (key AES256Key) MarshalJSON() ([]byte, error) {
	return json.Marshal(base64.StdEncoding.EncodeToString(key[:]))
}
// UnmarshalJSON sets the key from a JSON string holding a standard (padded)
// base64 value. It fails when data is not a JSON string, when the string is
// not valid base64, or when it does not decode to exactly 32 bytes. The key
// is left unchanged on failure.
func (key *AES256Key) UnmarshalJSON(data []byte) error {
	var encoded string
	err := json.Unmarshal(data, &encoded)
	if err != nil {
		return err
	}

	raw, err := base64.StdEncoding.DecodeString(encoded)
	switch {
	case err != nil:
		return fmt.Errorf("cannot decode base64 value: %w", err)
	case len(raw) != len(*key):
		return fmt.Errorf("invalid key size (must be 32 bytes long)")
	}

	copy(key[:], raw)
	return nil
}
// EncryptAES256 encrypts inputData with AES-256 in CBC mode.
//
// The input is padded to a multiple of the AES block size with PadPKCS5
// (defined elsewhere in the package), and a fresh IV is read from
// crypto/rand for every call. The returned slice is the IV (AES256IVSize
// bytes) followed by the ciphertext; nothing else, such as an authentication
// tag, is appended.
func (key *AES256Key) EncryptAES256(inputData []byte) ([]byte, error) {
	block, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("cannot create cipher: %w", err)
	}

	plaintext := PadPKCS5(inputData, aes.BlockSize)

	// One allocation holds both parts of the result: IV first, ciphertext after.
	out := make([]byte, AES256IVSize+len(plaintext))
	if _, err = rand.Read(out[:AES256IVSize]); err != nil {
		return nil, fmt.Errorf("cannot generate iv: %w", err)
	}

	cbc := cipher.NewCBCEncrypter(block, out[:AES256IVSize])
	cbc.CryptBlocks(out[AES256IVSize:], plaintext)
	return out, nil
}
// DecryptAES256 decrypts data laid out as an IV (AES256IVSize bytes)
// followed by AES-256-CBC ciphertext, the layout EncryptAES256 produces, and
// strips the padding with UnpadPKCS5 (defined elsewhere in the package).
//
// It returns an error when inputData is shorter than the IV, when the
// ciphertext length is not a multiple of the AES block size, or when
// UnpadPKCS5 rejects the decrypted data.
//
// inputData is never written to: decryption goes into a separate buffer, so
// the caller can reuse the ciphertext afterwards (for example to retry with
// another key) and the result does not alias the caller's slice.
func (key *AES256Key) DecryptAES256(inputData []byte) ([]byte, error) {
	blockCipher, err := aes.NewCipher(key[:])
	if err != nil {
		return nil, fmt.Errorf("cannot create cipher: %w", err)
	}
	if len(inputData) < AES256IVSize {
		return nil, fmt.Errorf("truncated data")
	}
	iv := inputData[:AES256IVSize]
	encryptedData := inputData[AES256IVSize:]
	if len(encryptedData)%aes.BlockSize != 0 {
		return nil, fmt.Errorf("invalid padded data length")
	}

	// Decrypt into a fresh buffer rather than in place, which would
	// overwrite the caller's ciphertext with plaintext.
	paddedData := make([]byte, len(encryptedData))
	decrypter := cipher.NewCBCDecrypter(blockCipher, iv)
	decrypter.CryptBlocks(paddedData, encryptedData)

	outputData, err := UnpadPKCS5(paddedData, aes.BlockSize)
	if err != nil {
		return nil, fmt.Errorf("invalid padded data: %w", err)
	}
	return outputData, nil
}