Implement a two-level key hierarchy: the MEK now wraps per-engine DEKs stored in a new barrier_keys table, rather than encrypting all barrier entries directly. A v2 ciphertext format (0x02) embeds the key ID so the barrier can resolve which DEK to use on decryption. v1 ciphertext remains supported for backward compatibility. Key changes: - crypto: EncryptV2/DecryptV2/ExtractKeyID for v2 ciphertext with key IDs - barrier: key registry (CreateKey, RotateKey, ListKeys, MigrateToV2, ReWrapKeys) - seal: RotateMEK re-wraps DEKs without re-encrypting data - engine: Mount auto-creates per-engine DEK - REST + gRPC: barrier/keys, barrier/rotate-mek, barrier/rotate-key, barrier/migrate - proto: BarrierService (v1 + v2) with ListKeys, RotateMEK, RotateKey, Migrate - db: migration v2 adds barrier_keys table Also includes: security audit report, CSRF protection, engine design specs (sshca, transit, user), path-bound AAD migration tool, policy engine enhancements, and ARCHITECTURE.md updates. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
297 lines
7.2 KiB
Go
297 lines
7.2 KiB
Go
package crypto
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"testing"
|
|
)
|
|
|
|
func TestGenerateKey(t *testing.T) {
|
|
key, err := GenerateKey()
|
|
if err != nil {
|
|
t.Fatalf("GenerateKey: %v", err)
|
|
}
|
|
if len(key) != KeySize {
|
|
t.Fatalf("key length: got %d, want %d", len(key), KeySize)
|
|
}
|
|
// Should be random (not all zeros).
|
|
if bytes.Equal(key, make([]byte, KeySize)) {
|
|
t.Fatal("key is all zeros")
|
|
}
|
|
}
|
|
|
|
func TestGenerateSalt(t *testing.T) {
|
|
salt, err := GenerateSalt()
|
|
if err != nil {
|
|
t.Fatalf("GenerateSalt: %v", err)
|
|
}
|
|
if len(salt) != SaltSize {
|
|
t.Fatalf("salt length: got %d, want %d", len(salt), SaltSize)
|
|
}
|
|
}
|
|
|
|
func TestEncryptDecrypt(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("hello, metacrypt!")
|
|
|
|
ciphertext, err := Encrypt(key, plaintext, nil)
|
|
if err != nil {
|
|
t.Fatalf("Encrypt: %v", err)
|
|
}
|
|
|
|
// Version byte should be present.
|
|
if ciphertext[0] != BarrierVersion {
|
|
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersion)
|
|
}
|
|
|
|
decrypted, err := Decrypt(key, ciphertext, nil)
|
|
if err != nil {
|
|
t.Fatalf("Decrypt: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(plaintext, decrypted) {
|
|
t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
|
|
}
|
|
}
|
|
|
|
func TestEncryptDecryptWithAAD(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("hello, metacrypt!")
|
|
aad := []byte("engine/ca/pki/root/cert.pem")
|
|
|
|
ciphertext, err := Encrypt(key, plaintext, aad)
|
|
if err != nil {
|
|
t.Fatalf("Encrypt with AAD: %v", err)
|
|
}
|
|
|
|
// Decrypt with correct AAD succeeds.
|
|
decrypted, err := Decrypt(key, ciphertext, aad)
|
|
if err != nil {
|
|
t.Fatalf("Decrypt with AAD: %v", err)
|
|
}
|
|
if !bytes.Equal(plaintext, decrypted) {
|
|
t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
|
|
}
|
|
|
|
// Decrypt with wrong AAD fails.
|
|
_, err = Decrypt(key, ciphertext, []byte("wrong/path"))
|
|
if !errors.Is(err, ErrDecryptionFailed) {
|
|
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
|
|
}
|
|
|
|
// Decrypt with nil AAD fails.
|
|
_, err = Decrypt(key, ciphertext, nil)
|
|
if !errors.Is(err, ErrDecryptionFailed) {
|
|
t.Fatalf("expected ErrDecryptionFailed with nil AAD, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDecryptWrongKey(t *testing.T) {
|
|
key1, _ := GenerateKey()
|
|
key2, _ := GenerateKey()
|
|
plaintext := []byte("secret data")
|
|
|
|
ciphertext, _ := Encrypt(key1, plaintext, nil)
|
|
_, err := Decrypt(key2, ciphertext, nil)
|
|
if !errors.Is(err, ErrDecryptionFailed) {
|
|
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDecryptInvalidCiphertext(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
_, err := Decrypt(key, []byte("short"), nil)
|
|
if !errors.Is(err, ErrInvalidCiphertext) {
|
|
t.Fatalf("expected ErrInvalidCiphertext, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestDeriveKey(t *testing.T) {
|
|
password := []byte("test-password")
|
|
salt, _ := GenerateSalt()
|
|
params := Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
|
|
|
key := DeriveKey(password, salt, params)
|
|
if len(key) != KeySize {
|
|
t.Fatalf("derived key length: got %d, want %d", len(key), KeySize)
|
|
}
|
|
|
|
// Same inputs should produce same output.
|
|
key2 := DeriveKey(password, salt, params)
|
|
if !bytes.Equal(key, key2) {
|
|
t.Fatal("determinism: same inputs produced different keys")
|
|
}
|
|
|
|
// Different password should produce different output.
|
|
key3 := DeriveKey([]byte("different"), salt, params)
|
|
if bytes.Equal(key, key3) {
|
|
t.Fatal("different passwords produced same key")
|
|
}
|
|
}
|
|
|
|
func TestZeroize(t *testing.T) {
|
|
data := []byte{1, 2, 3, 4, 5}
|
|
Zeroize(data)
|
|
for i, b := range data {
|
|
if b != 0 {
|
|
t.Fatalf("byte %d not zeroed: %d", i, b)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestConstantTimeEqual(t *testing.T) {
|
|
a := []byte("hello")
|
|
b := []byte("hello")
|
|
c := []byte("world")
|
|
|
|
if !ConstantTimeEqual(a, b) {
|
|
t.Fatal("equal slices reported as not equal")
|
|
}
|
|
if ConstantTimeEqual(a, c) {
|
|
t.Fatal("different slices reported as equal")
|
|
}
|
|
}
|
|
|
|
func TestEncryptProducesDifferentCiphertext(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("same data")
|
|
|
|
ct1, _ := Encrypt(key, plaintext, nil)
|
|
ct2, _ := Encrypt(key, plaintext, nil)
|
|
|
|
if bytes.Equal(ct1, ct2) {
|
|
t.Fatal("two encryptions of same plaintext produced identical ciphertext (nonce reuse)")
|
|
}
|
|
}
|
|
|
|
func TestV2EncryptDecryptRoundtrip(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("v2 test data")
|
|
keyID := "engine/ca/prod"
|
|
aad := []byte("engine/ca/prod/config.json")
|
|
|
|
ciphertext, err := EncryptV2(key, keyID, plaintext, aad)
|
|
if err != nil {
|
|
t.Fatalf("EncryptV2: %v", err)
|
|
}
|
|
|
|
if ciphertext[0] != BarrierVersionV2 {
|
|
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersionV2)
|
|
}
|
|
|
|
pt, gotKeyID, err := DecryptV2(key, ciphertext, aad)
|
|
if err != nil {
|
|
t.Fatalf("DecryptV2: %v", err)
|
|
}
|
|
if gotKeyID != keyID {
|
|
t.Fatalf("key ID: got %q, want %q", gotKeyID, keyID)
|
|
}
|
|
if !bytes.Equal(plaintext, pt) {
|
|
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
|
|
}
|
|
}
|
|
|
|
func TestV2DecryptV1Compat(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("v1 legacy data")
|
|
|
|
// Encrypt with v1.
|
|
v1ct, err := Encrypt(key, plaintext, nil)
|
|
if err != nil {
|
|
t.Fatalf("Encrypt v1: %v", err)
|
|
}
|
|
|
|
// DecryptV2 should handle v1 ciphertext.
|
|
pt, keyID, err := DecryptV2(key, v1ct, nil)
|
|
if err != nil {
|
|
t.Fatalf("DecryptV2 with v1 ciphertext: %v", err)
|
|
}
|
|
if keyID != "" {
|
|
t.Fatalf("expected empty key ID for v1, got %q", keyID)
|
|
}
|
|
if !bytes.Equal(plaintext, pt) {
|
|
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
|
|
}
|
|
}
|
|
|
|
func TestV2WrongAAD(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("data")
|
|
aad := []byte("correct/path")
|
|
|
|
ct, _ := EncryptV2(key, "system", plaintext, aad)
|
|
|
|
_, _, err := DecryptV2(key, ct, []byte("wrong/path"))
|
|
if !errors.Is(err, ErrDecryptionFailed) {
|
|
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestV2WrongKey(t *testing.T) {
|
|
key1, _ := GenerateKey()
|
|
key2, _ := GenerateKey()
|
|
plaintext := []byte("data")
|
|
|
|
ct, _ := EncryptV2(key1, "system", plaintext, nil)
|
|
|
|
_, _, err := DecryptV2(key2, ct, nil)
|
|
if !errors.Is(err, ErrDecryptionFailed) {
|
|
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestExtractKeyID(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
|
|
// v1: empty key ID.
|
|
v1ct, _ := Encrypt(key, []byte("data"), nil)
|
|
kid, err := ExtractKeyID(v1ct)
|
|
if err != nil {
|
|
t.Fatalf("ExtractKeyID v1: %v", err)
|
|
}
|
|
if kid != "" {
|
|
t.Fatalf("expected empty key ID for v1, got %q", kid)
|
|
}
|
|
|
|
// v2: embedded key ID.
|
|
v2ct, _ := EncryptV2(key, "engine/transit/main", []byte("data"), nil)
|
|
kid, err = ExtractKeyID(v2ct)
|
|
if err != nil {
|
|
t.Fatalf("ExtractKeyID v2: %v", err)
|
|
}
|
|
if kid != "engine/transit/main" {
|
|
t.Fatalf("key ID: got %q, want %q", kid, "engine/transit/main")
|
|
}
|
|
}
|
|
|
|
func TestV2KeyIDTooLong(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
longID := string(make([]byte, MaxKeyIDLen+1))
|
|
|
|
_, err := EncryptV2(key, longID, []byte("data"), nil)
|
|
if !errors.Is(err, ErrKeyIDTooLong) {
|
|
t.Fatalf("expected ErrKeyIDTooLong, got: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestV2EmptyKeyID(t *testing.T) {
|
|
key, _ := GenerateKey()
|
|
plaintext := []byte("data with empty key id")
|
|
|
|
ct, err := EncryptV2(key, "", plaintext, nil)
|
|
if err != nil {
|
|
t.Fatalf("EncryptV2 empty key ID: %v", err)
|
|
}
|
|
|
|
pt, keyID, err := DecryptV2(key, ct, nil)
|
|
if err != nil {
|
|
t.Fatalf("DecryptV2 empty key ID: %v", err)
|
|
}
|
|
if keyID != "" {
|
|
t.Fatalf("expected empty key ID, got %q", keyID)
|
|
}
|
|
if !bytes.Equal(plaintext, pt) {
|
|
t.Fatalf("roundtrip failed")
|
|
}
|
|
}
|