Implement a two-level key hierarchy: the MEK now wraps per-engine DEKs stored in a new barrier_keys table, rather than encrypting all barrier entries directly. A v2 ciphertext format (0x02) embeds the key ID so the barrier can resolve which DEK to use on decryption. v1 ciphertext remains supported for backward compatibility. Key changes: - crypto: EncryptV2/DecryptV2/ExtractKeyID for v2 ciphertext with key IDs - barrier: key registry (CreateKey, RotateKey, ListKeys, MigrateToV2, ReWrapKeys) - seal: RotateMEK re-wraps DEKs without re-encrypting data - engine: Mount auto-creates per-engine DEK - REST + gRPC: barrier/keys, barrier/rotate-mek, barrier/rotate-key, barrier/migrate - proto: BarrierService (v1 + v2) with ListKeys, RotateMEK, RotateKey, Migrate - db: migration v2 adds barrier_keys table Also includes: security audit report, CSRF protection, engine design specs (sshca, transit, user), path-bound AAD migration tool, policy engine enhancements, and ARCHITECTURE.md updates. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -20,8 +20,16 @@ const (
|
||||
// SaltSize is the size of Argon2id salts in bytes.
|
||||
SaltSize = 32
|
||||
|
||||
// BarrierVersion is the version byte prefix for encrypted barrier entries.
|
||||
BarrierVersion byte = 0x01
|
||||
// BarrierVersionV1 is the v1 format: [version][nonce][ciphertext+tag].
|
||||
BarrierVersionV1 byte = 0x01
|
||||
// BarrierVersionV2 is the v2 format: [version][key_id_len][key_id][nonce][ciphertext+tag].
|
||||
BarrierVersionV2 byte = 0x02
|
||||
|
||||
// BarrierVersion is kept for backward compatibility (alias for V1).
|
||||
BarrierVersion = BarrierVersionV1
|
||||
|
||||
// MaxKeyIDLen is the maximum length of a key ID in the v2 format.
|
||||
MaxKeyIDLen = 255
|
||||
|
||||
// Default Argon2id parameters.
|
||||
DefaultArgon2Time = 3
|
||||
@@ -32,6 +40,7 @@ const (
|
||||
var (
|
||||
ErrInvalidCiphertext = errors.New("crypto: invalid ciphertext")
|
||||
ErrDecryptionFailed = errors.New("crypto: decryption failed")
|
||||
ErrKeyIDTooLong = errors.New("crypto: key ID exceeds maximum length")
|
||||
)
|
||||
|
||||
// Argon2Params holds Argon2id KDF parameters.
|
||||
@@ -74,8 +83,10 @@ func GenerateSalt() ([]byte, error) {
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext with AES-256-GCM using the given key.
|
||||
// The additionalData parameter is authenticated but not encrypted (AAD);
|
||||
// pass nil when no binding context is needed.
|
||||
// Returns: [version byte][12-byte nonce][ciphertext+tag]
|
||||
func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
func Encrypt(key, plaintext, additionalData []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new cipher: %w", err)
|
||||
@@ -88,7 +99,7 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
|
||||
}
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData)
|
||||
|
||||
// Format: [version][nonce][ciphertext+tag]
|
||||
result := make([]byte, 1+NonceSize+len(ciphertext))
|
||||
@@ -99,7 +110,8 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext produced by Encrypt.
|
||||
func Decrypt(key, data []byte) ([]byte, error) {
|
||||
// The additionalData must match the value provided during encryption.
|
||||
func Decrypt(key, data, additionalData []byte) ([]byte, error) {
|
||||
if len(data) < 1+NonceSize+aes.BlockSize {
|
||||
return nil, ErrInvalidCiphertext
|
||||
}
|
||||
@@ -117,13 +129,114 @@ func Decrypt(key, data []byte) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, additionalData)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptV2 encrypts plaintext with AES-256-GCM, embedding a key ID in the ciphertext.
|
||||
// Format: [0x02][key_id_len:1][key_id:N][nonce:12][ciphertext+tag]
|
||||
func EncryptV2(key []byte, keyID string, plaintext, additionalData []byte) ([]byte, error) {
|
||||
if len(keyID) > MaxKeyIDLen {
|
||||
return nil, ErrKeyIDTooLong
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
nonce := make([]byte, NonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
|
||||
}
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData)
|
||||
|
||||
kidLen := len(keyID)
|
||||
// Format: [version][key_id_len][key_id][nonce][ciphertext+tag]
|
||||
result := make([]byte, 1+1+kidLen+NonceSize+len(ciphertext))
|
||||
result[0] = BarrierVersionV2
|
||||
result[1] = byte(kidLen)
|
||||
copy(result[2:2+kidLen], keyID)
|
||||
copy(result[2+kidLen:2+kidLen+NonceSize], nonce)
|
||||
copy(result[2+kidLen+NonceSize:], ciphertext)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptV2 decrypts ciphertext that may be in v1 or v2 format.
|
||||
// For v2 format, it extracts the key ID and returns it alongside the plaintext.
|
||||
// For v1 format, it returns an empty key ID.
|
||||
func DecryptV2(key, data, additionalData []byte) (plaintext []byte, keyID string, err error) {
|
||||
if len(data) < 1 {
|
||||
return nil, "", ErrInvalidCiphertext
|
||||
}
|
||||
|
||||
switch data[0] {
|
||||
case BarrierVersionV1:
|
||||
pt, err := Decrypt(key, data, additionalData)
|
||||
return pt, "", err
|
||||
|
||||
case BarrierVersionV2:
|
||||
if len(data) < 2 {
|
||||
return nil, "", ErrInvalidCiphertext
|
||||
}
|
||||
kidLen := int(data[1])
|
||||
headerLen := 2 + kidLen
|
||||
if len(data) < headerLen+NonceSize+aes.BlockSize {
|
||||
return nil, "", ErrInvalidCiphertext
|
||||
}
|
||||
|
||||
keyID = string(data[2 : 2+kidLen])
|
||||
nonce := data[headerLen : headerLen+NonceSize]
|
||||
ciphertext := data[headerLen+NonceSize:]
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("crypto: new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
pt, err := gcm.Open(nil, nonce, ciphertext, additionalData)
|
||||
if err != nil {
|
||||
return nil, "", ErrDecryptionFailed
|
||||
}
|
||||
return pt, keyID, nil
|
||||
|
||||
default:
|
||||
return nil, "", fmt.Errorf("crypto: unsupported version: %d", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractKeyID returns the key ID from a v2 ciphertext without decrypting.
|
||||
// Returns empty string for v1 ciphertext.
|
||||
func ExtractKeyID(data []byte) (string, error) {
|
||||
if len(data) < 1 {
|
||||
return "", ErrInvalidCiphertext
|
||||
}
|
||||
switch data[0] {
|
||||
case BarrierVersionV1:
|
||||
return "", nil
|
||||
case BarrierVersionV2:
|
||||
if len(data) < 2 {
|
||||
return "", ErrInvalidCiphertext
|
||||
}
|
||||
kidLen := int(data[1])
|
||||
if len(data) < 2+kidLen {
|
||||
return "", ErrInvalidCiphertext
|
||||
}
|
||||
return string(data[2 : 2+kidLen]), nil
|
||||
default:
|
||||
return "", fmt.Errorf("crypto: unsupported version: %d", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Zeroize overwrites a byte slice with zeros.
|
||||
func Zeroize(b []byte) {
|
||||
for i := range b {
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("hello, metacrypt!")
|
||||
|
||||
ciphertext, err := Encrypt(key, plaintext)
|
||||
ciphertext, err := Encrypt(key, plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersion)
|
||||
}
|
||||
|
||||
decrypted, err := Decrypt(key, ciphertext)
|
||||
decrypted, err := Decrypt(key, ciphertext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt: %v", err)
|
||||
}
|
||||
@@ -54,13 +54,45 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptWithAAD(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("hello, metacrypt!")
|
||||
aad := []byte("engine/ca/pki/root/cert.pem")
|
||||
|
||||
ciphertext, err := Encrypt(key, plaintext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt with AAD: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt with correct AAD succeeds.
|
||||
decrypted, err := Decrypt(key, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt with AAD: %v", err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
|
||||
}
|
||||
|
||||
// Decrypt with wrong AAD fails.
|
||||
_, err = Decrypt(key, ciphertext, []byte("wrong/path"))
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt with nil AAD fails.
|
||||
_, err = Decrypt(key, ciphertext, nil)
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed with nil AAD, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWrongKey(t *testing.T) {
|
||||
key1, _ := GenerateKey()
|
||||
key2, _ := GenerateKey()
|
||||
plaintext := []byte("secret data")
|
||||
|
||||
ciphertext, _ := Encrypt(key1, plaintext)
|
||||
_, err := Decrypt(key2, ciphertext)
|
||||
ciphertext, _ := Encrypt(key1, plaintext, nil)
|
||||
_, err := Decrypt(key2, ciphertext, nil)
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
||||
}
|
||||
@@ -68,7 +100,7 @@ func TestDecryptWrongKey(t *testing.T) {
|
||||
|
||||
func TestDecryptInvalidCiphertext(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
_, err := Decrypt(key, []byte("short"))
|
||||
_, err := Decrypt(key, []byte("short"), nil)
|
||||
if !errors.Is(err, ErrInvalidCiphertext) {
|
||||
t.Fatalf("expected ErrInvalidCiphertext, got: %v", err)
|
||||
}
|
||||
@@ -124,10 +156,141 @@ func TestEncryptProducesDifferentCiphertext(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("same data")
|
||||
|
||||
ct1, _ := Encrypt(key, plaintext)
|
||||
ct2, _ := Encrypt(key, plaintext)
|
||||
ct1, _ := Encrypt(key, plaintext, nil)
|
||||
ct2, _ := Encrypt(key, plaintext, nil)
|
||||
|
||||
if bytes.Equal(ct1, ct2) {
|
||||
t.Fatal("two encryptions of same plaintext produced identical ciphertext (nonce reuse)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2EncryptDecryptRoundtrip(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("v2 test data")
|
||||
keyID := "engine/ca/prod"
|
||||
aad := []byte("engine/ca/prod/config.json")
|
||||
|
||||
ciphertext, err := EncryptV2(key, keyID, plaintext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptV2: %v", err)
|
||||
}
|
||||
|
||||
if ciphertext[0] != BarrierVersionV2 {
|
||||
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersionV2)
|
||||
}
|
||||
|
||||
pt, gotKeyID, err := DecryptV2(key, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptV2: %v", err)
|
||||
}
|
||||
if gotKeyID != keyID {
|
||||
t.Fatalf("key ID: got %q, want %q", gotKeyID, keyID)
|
||||
}
|
||||
if !bytes.Equal(plaintext, pt) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2DecryptV1Compat(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("v1 legacy data")
|
||||
|
||||
// Encrypt with v1.
|
||||
v1ct, err := Encrypt(key, plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt v1: %v", err)
|
||||
}
|
||||
|
||||
// DecryptV2 should handle v1 ciphertext.
|
||||
pt, keyID, err := DecryptV2(key, v1ct, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptV2 with v1 ciphertext: %v", err)
|
||||
}
|
||||
if keyID != "" {
|
||||
t.Fatalf("expected empty key ID for v1, got %q", keyID)
|
||||
}
|
||||
if !bytes.Equal(plaintext, pt) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2WrongAAD(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("data")
|
||||
aad := []byte("correct/path")
|
||||
|
||||
ct, _ := EncryptV2(key, "system", plaintext, aad)
|
||||
|
||||
_, _, err := DecryptV2(key, ct, []byte("wrong/path"))
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2WrongKey(t *testing.T) {
|
||||
key1, _ := GenerateKey()
|
||||
key2, _ := GenerateKey()
|
||||
plaintext := []byte("data")
|
||||
|
||||
ct, _ := EncryptV2(key1, "system", plaintext, nil)
|
||||
|
||||
_, _, err := DecryptV2(key2, ct, nil)
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractKeyID(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
|
||||
// v1: empty key ID.
|
||||
v1ct, _ := Encrypt(key, []byte("data"), nil)
|
||||
kid, err := ExtractKeyID(v1ct)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractKeyID v1: %v", err)
|
||||
}
|
||||
if kid != "" {
|
||||
t.Fatalf("expected empty key ID for v1, got %q", kid)
|
||||
}
|
||||
|
||||
// v2: embedded key ID.
|
||||
v2ct, _ := EncryptV2(key, "engine/transit/main", []byte("data"), nil)
|
||||
kid, err = ExtractKeyID(v2ct)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractKeyID v2: %v", err)
|
||||
}
|
||||
if kid != "engine/transit/main" {
|
||||
t.Fatalf("key ID: got %q, want %q", kid, "engine/transit/main")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2KeyIDTooLong(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
longID := string(make([]byte, MaxKeyIDLen+1))
|
||||
|
||||
_, err := EncryptV2(key, longID, []byte("data"), nil)
|
||||
if !errors.Is(err, ErrKeyIDTooLong) {
|
||||
t.Fatalf("expected ErrKeyIDTooLong, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2EmptyKeyID(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("data with empty key id")
|
||||
|
||||
ct, err := EncryptV2(key, "", plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptV2 empty key ID: %v", err)
|
||||
}
|
||||
|
||||
pt, keyID, err := DecryptV2(key, ct, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptV2 empty key ID: %v", err)
|
||||
}
|
||||
if keyID != "" {
|
||||
t.Fatalf("expected empty key ID, got %q", keyID)
|
||||
}
|
||||
if !bytes.Equal(plaintext, pt) {
|
||||
t.Fatalf("roundtrip failed")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user