// Package crypto provides Argon2id KDF, AES-256-GCM encryption, and key helpers. package crypto import ( "crypto/aes" "crypto/cipher" "crypto/rand" "crypto/subtle" "errors" "fmt" "golang.org/x/crypto/argon2" ) const ( // KeySize is the size of AES-256 keys in bytes. KeySize = 32 // NonceSize is the size of AES-GCM nonces in bytes. NonceSize = 12 // SaltSize is the size of Argon2id salts in bytes. SaltSize = 32 // BarrierVersionV1 is the v1 format: [version][nonce][ciphertext+tag]. BarrierVersionV1 byte = 0x01 // BarrierVersionV2 is the v2 format: [version][key_id_len][key_id][nonce][ciphertext+tag]. BarrierVersionV2 byte = 0x02 // BarrierVersion is kept for backward compatibility (alias for V1). BarrierVersion = BarrierVersionV1 // MaxKeyIDLen is the maximum length of a key ID in the v2 format. MaxKeyIDLen = 255 // Default Argon2id parameters. DefaultArgon2Time = 3 DefaultArgon2Memory = 128 * 1024 // 128 MiB in KiB DefaultArgon2Threads = 4 ) var ( ErrInvalidCiphertext = errors.New("crypto: invalid ciphertext") ErrDecryptionFailed = errors.New("crypto: decryption failed") ErrKeyIDTooLong = errors.New("crypto: key ID exceeds maximum length") ) // Argon2Params holds Argon2id KDF parameters. type Argon2Params struct { Time uint32 Memory uint32 // in KiB Threads uint8 } // DefaultArgon2Params returns the default Argon2id parameters. func DefaultArgon2Params() Argon2Params { return Argon2Params{ Time: DefaultArgon2Time, Memory: DefaultArgon2Memory, Threads: DefaultArgon2Threads, } } // DeriveKey derives a 256-bit key from password and salt using Argon2id. func DeriveKey(password []byte, salt []byte, params Argon2Params) []byte { return argon2.IDKey(password, salt, params.Time, params.Memory, params.Threads, KeySize) } // GenerateKey generates a random 256-bit key. func GenerateKey() ([]byte, error) { key := make([]byte, KeySize) if _, err := rand.Read(key); err != nil { return nil, fmt.Errorf("crypto: generate key: %w", err) } return key, nil } // GenerateSalt generates a random salt for Argon2id. func GenerateSalt() ([]byte, error) { salt := make([]byte, SaltSize) if _, err := rand.Read(salt); err != nil { return nil, fmt.Errorf("crypto: generate salt: %w", err) } return salt, nil } // Encrypt encrypts plaintext with AES-256-GCM using the given key. // The additionalData parameter is authenticated but not encrypted (AAD); // pass nil when no binding context is needed. // Returns: [version byte][12-byte nonce][ciphertext+tag] func Encrypt(key, plaintext, additionalData []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { return nil, fmt.Errorf("crypto: new cipher: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("crypto: new gcm: %w", err) } nonce := make([]byte, NonceSize) if _, err := rand.Read(nonce); err != nil { return nil, fmt.Errorf("crypto: generate nonce: %w", err) } ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData) // Format: [version][nonce][ciphertext+tag] result := make([]byte, 1+NonceSize+len(ciphertext)) result[0] = BarrierVersion copy(result[1:1+NonceSize], nonce) copy(result[1+NonceSize:], ciphertext) return result, nil } // Decrypt decrypts ciphertext produced by Encrypt. // The additionalData must match the value provided during encryption. func Decrypt(key, data, additionalData []byte) ([]byte, error) { if len(data) < 1+NonceSize+aes.BlockSize { return nil, ErrInvalidCiphertext } if data[0] != BarrierVersion { return nil, fmt.Errorf("crypto: unsupported version: %d", data[0]) } nonce := data[1 : 1+NonceSize] ciphertext := data[1+NonceSize:] block, err := aes.NewCipher(key) if err != nil { return nil, fmt.Errorf("crypto: new cipher: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("crypto: new gcm: %w", err) } plaintext, err := gcm.Open(nil, nonce, ciphertext, additionalData) if err != nil { return nil, ErrDecryptionFailed } return plaintext, nil } // EncryptV2 encrypts plaintext with AES-256-GCM, embedding a key ID in the ciphertext. // Format: [0x02][key_id_len:1][key_id:N][nonce:12][ciphertext+tag] func EncryptV2(key []byte, keyID string, plaintext, additionalData []byte) ([]byte, error) { if len(keyID) > MaxKeyIDLen { return nil, ErrKeyIDTooLong } block, err := aes.NewCipher(key) if err != nil { return nil, fmt.Errorf("crypto: new cipher: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { return nil, fmt.Errorf("crypto: new gcm: %w", err) } nonce := make([]byte, NonceSize) if _, err := rand.Read(nonce); err != nil { return nil, fmt.Errorf("crypto: generate nonce: %w", err) } ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData) kidLen := len(keyID) // Format: [version][key_id_len][key_id][nonce][ciphertext+tag] result := make([]byte, 1+1+kidLen+NonceSize+len(ciphertext)) result[0] = BarrierVersionV2 result[1] = byte(kidLen) copy(result[2:2+kidLen], keyID) copy(result[2+kidLen:2+kidLen+NonceSize], nonce) copy(result[2+kidLen+NonceSize:], ciphertext) return result, nil } // DecryptV2 decrypts ciphertext that may be in v1 or v2 format. // For v2 format, it extracts the key ID and returns it alongside the plaintext. // For v1 format, it returns an empty key ID. func DecryptV2(key, data, additionalData []byte) (plaintext []byte, keyID string, err error) { if len(data) < 1 { return nil, "", ErrInvalidCiphertext } switch data[0] { case BarrierVersionV1: pt, err := Decrypt(key, data, additionalData) return pt, "", err case BarrierVersionV2: if len(data) < 2 { return nil, "", ErrInvalidCiphertext } kidLen := int(data[1]) headerLen := 2 + kidLen if len(data) < headerLen+NonceSize+aes.BlockSize { return nil, "", ErrInvalidCiphertext } keyID = string(data[2 : 2+kidLen]) nonce := data[headerLen : headerLen+NonceSize] ciphertext := data[headerLen+NonceSize:] block, err := aes.NewCipher(key) if err != nil { return nil, "", fmt.Errorf("crypto: new cipher: %w", err) } gcm, err := cipher.NewGCM(block) if err != nil { return nil, "", fmt.Errorf("crypto: new gcm: %w", err) } pt, err := gcm.Open(nil, nonce, ciphertext, additionalData) if err != nil { return nil, "", ErrDecryptionFailed } return pt, keyID, nil default: return nil, "", fmt.Errorf("crypto: unsupported version: %d", data[0]) } } // ExtractKeyID returns the key ID from a v2 ciphertext without decrypting. // Returns empty string for v1 ciphertext. func ExtractKeyID(data []byte) (string, error) { if len(data) < 1 { return "", ErrInvalidCiphertext } switch data[0] { case BarrierVersionV1: return "", nil case BarrierVersionV2: if len(data) < 2 { return "", ErrInvalidCiphertext } kidLen := int(data[1]) if len(data) < 2+kidLen { return "", ErrInvalidCiphertext } return string(data[2 : 2+kidLen]), nil default: return "", fmt.Errorf("crypto: unsupported version: %d", data[0]) } } // Zeroize overwrites a byte slice with zeros. func Zeroize(b []byte) { for i := range b { b[i] = 0 } } // ConstantTimeEqual compares two byte slices in constant time. func ConstantTimeEqual(a, b []byte) bool { return subtle.ConstantTimeCompare(a, b) == 1 }