package crypto

import (
	"bytes"
	"errors"
	"strings"
	"testing"
)

// newTestKey returns a fresh key from GenerateKey, failing the calling test
// immediately if generation errors. Tests that are not about key generation
// use it so a generation failure is reported as such, rather than surfacing
// later as a confusing Encrypt/Decrypt failure on a nil key.
func newTestKey(t *testing.T) []byte {
	t.Helper()
	key, err := GenerateKey()
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}
	return key
}

// encryptOrFatal wraps Encrypt and fails the calling test on error. Without
// this, a failed setup encryption yields a nil ciphertext and the assertion
// that follows misreports the cause (e.g. two nil ciphertexts compare equal
// and look like nonce reuse).
func encryptOrFatal(t *testing.T, key, plaintext, aad []byte) []byte {
	t.Helper()
	ct, err := Encrypt(key, plaintext, aad)
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}
	return ct
}

// encryptV2OrFatal wraps EncryptV2 and fails the calling test on error, for
// the same reason as encryptOrFatal.
func encryptV2OrFatal(t *testing.T, key []byte, keyID string, plaintext, aad []byte) []byte {
	t.Helper()
	ct, err := EncryptV2(key, keyID, plaintext, aad)
	if err != nil {
		t.Fatalf("EncryptV2(keyID=%q): %v", keyID, err)
	}
	return ct
}

func TestGenerateKey(t *testing.T) {
	key, err := GenerateKey()
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}
	if len(key) != KeySize {
		t.Fatalf("key length: got %d, want %d", len(key), KeySize)
	}
	// Should be random (not all zeros).
	if bytes.Equal(key, make([]byte, KeySize)) {
		t.Fatal("key is all zeros")
	}
}

func TestGenerateSalt(t *testing.T) {
	salt, err := GenerateSalt()
	if err != nil {
		t.Fatalf("GenerateSalt: %v", err)
	}
	if len(salt) != SaltSize {
		t.Fatalf("salt length: got %d, want %d", len(salt), SaltSize)
	}
}

func TestEncryptDecrypt(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("hello, metacrypt!")

	ciphertext, err := Encrypt(key, plaintext, nil)
	if err != nil {
		t.Fatalf("Encrypt: %v", err)
	}

	// Version byte should be present. Check the length first so an empty
	// result fails cleanly instead of panicking on the index.
	if len(ciphertext) == 0 {
		t.Fatal("Encrypt returned empty ciphertext")
	}
	if ciphertext[0] != BarrierVersion {
		t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersion)
	}

	decrypted, err := Decrypt(key, ciphertext, nil)
	if err != nil {
		t.Fatalf("Decrypt: %v", err)
	}
	if !bytes.Equal(plaintext, decrypted) {
		t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
	}
}

func TestEncryptDecryptWithAAD(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("hello, metacrypt!")
	aad := []byte("engine/ca/pki/root/cert.pem")

	ciphertext, err := Encrypt(key, plaintext, aad)
	if err != nil {
		t.Fatalf("Encrypt with AAD: %v", err)
	}

	// Decrypt with correct AAD succeeds.
	decrypted, err := Decrypt(key, ciphertext, aad)
	if err != nil {
		t.Fatalf("Decrypt with AAD: %v", err)
	}
	if !bytes.Equal(plaintext, decrypted) {
		t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
	}

	// Decrypt with wrong AAD fails.
	_, err = Decrypt(key, ciphertext, []byte("wrong/path"))
	if !errors.Is(err, ErrDecryptionFailed) {
		t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
	}

	// Decrypt with nil AAD fails.
	_, err = Decrypt(key, ciphertext, nil)
	if !errors.Is(err, ErrDecryptionFailed) {
		t.Fatalf("expected ErrDecryptionFailed with nil AAD, got: %v", err)
	}
}

func TestDecryptWrongKey(t *testing.T) {
	key1 := newTestKey(t)
	key2 := newTestKey(t)
	plaintext := []byte("secret data")

	ciphertext := encryptOrFatal(t, key1, plaintext, nil)
	_, err := Decrypt(key2, ciphertext, nil)
	if !errors.Is(err, ErrDecryptionFailed) {
		t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
	}
}

func TestDecryptInvalidCiphertext(t *testing.T) {
	key := newTestKey(t)
	_, err := Decrypt(key, []byte("short"), nil)
	if !errors.Is(err, ErrInvalidCiphertext) {
		t.Fatalf("expected ErrInvalidCiphertext, got: %v", err)
	}
}

func TestDeriveKey(t *testing.T) {
	password := []byte("test-password")
	salt, err := GenerateSalt()
	if err != nil {
		t.Fatalf("GenerateSalt: %v", err)
	}
	params := Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}

	key := DeriveKey(password, salt, params)
	if len(key) != KeySize {
		t.Fatalf("derived key length: got %d, want %d", len(key), KeySize)
	}

	// Same inputs should produce same output.
	key2 := DeriveKey(password, salt, params)
	if !bytes.Equal(key, key2) {
		t.Fatal("determinism: same inputs produced different keys")
	}

	// Different password should produce different output.
	key3 := DeriveKey([]byte("different"), salt, params)
	if bytes.Equal(key, key3) {
		t.Fatal("different passwords produced same key")
	}
}

func TestZeroize(t *testing.T) {
	data := []byte{1, 2, 3, 4, 5}
	Zeroize(data)
	for i, b := range data {
		if b != 0 {
			t.Fatalf("byte %d not zeroed: %d", i, b)
		}
	}
}

func TestConstantTimeEqual(t *testing.T) {
	a := []byte("hello")
	b := []byte("hello")
	c := []byte("world")
	// d shares a's bytes as a prefix; a comparison that only walks the
	// shorter input would wrongly report these as equal.
	d := []byte("hello, world")

	if !ConstantTimeEqual(a, b) {
		t.Fatal("equal slices reported as not equal")
	}
	if ConstantTimeEqual(a, c) {
		t.Fatal("different slices reported as equal")
	}
	if ConstantTimeEqual(a, d) {
		t.Fatal("slices of different length reported as equal")
	}
}

func TestEncryptProducesDifferentCiphertext(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("same data")

	ct1 := encryptOrFatal(t, key, plaintext, nil)
	ct2 := encryptOrFatal(t, key, plaintext, nil)
	if bytes.Equal(ct1, ct2) {
		t.Fatal("two encryptions of same plaintext produced identical ciphertext (nonce reuse)")
	}
}

func TestV2EncryptDecryptRoundtrip(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("v2 test data")
	keyID := "engine/ca/prod"
	aad := []byte("engine/ca/prod/config.json")

	ciphertext, err := EncryptV2(key, keyID, plaintext, aad)
	if err != nil {
		t.Fatalf("EncryptV2: %v", err)
	}

	// Guard the index so an empty result fails instead of panicking.
	if len(ciphertext) == 0 {
		t.Fatal("EncryptV2 returned empty ciphertext")
	}
	if ciphertext[0] != BarrierVersionV2 {
		t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersionV2)
	}

	pt, gotKeyID, err := DecryptV2(key, ciphertext, aad)
	if err != nil {
		t.Fatalf("DecryptV2: %v", err)
	}
	if gotKeyID != keyID {
		t.Fatalf("key ID: got %q, want %q", gotKeyID, keyID)
	}
	if !bytes.Equal(plaintext, pt) {
		t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
	}
}

func TestV2DecryptV1Compat(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("v1 legacy data")

	// Encrypt with v1.
	v1ct, err := Encrypt(key, plaintext, nil)
	if err != nil {
		t.Fatalf("Encrypt v1: %v", err)
	}

	// DecryptV2 should handle v1 ciphertext.
	pt, keyID, err := DecryptV2(key, v1ct, nil)
	if err != nil {
		t.Fatalf("DecryptV2 with v1 ciphertext: %v", err)
	}
	if keyID != "" {
		t.Fatalf("expected empty key ID for v1, got %q", keyID)
	}
	if !bytes.Equal(plaintext, pt) {
		t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
	}
}

func TestV2WrongAAD(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("data")
	aad := []byte("correct/path")

	ct := encryptV2OrFatal(t, key, "system", plaintext, aad)
	_, _, err := DecryptV2(key, ct, []byte("wrong/path"))
	if !errors.Is(err, ErrDecryptionFailed) {
		t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
	}
}

func TestV2WrongKey(t *testing.T) {
	key1 := newTestKey(t)
	key2 := newTestKey(t)
	plaintext := []byte("data")

	ct := encryptV2OrFatal(t, key1, "system", plaintext, nil)
	_, _, err := DecryptV2(key2, ct, nil)
	if !errors.Is(err, ErrDecryptionFailed) {
		t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
	}
}

func TestExtractKeyID(t *testing.T) {
	key := newTestKey(t)

	// v1: empty key ID.
	v1ct := encryptOrFatal(t, key, []byte("data"), nil)
	kid, err := ExtractKeyID(v1ct)
	if err != nil {
		t.Fatalf("ExtractKeyID v1: %v", err)
	}
	if kid != "" {
		t.Fatalf("expected empty key ID for v1, got %q", kid)
	}

	// v2: embedded key ID.
	v2ct := encryptV2OrFatal(t, key, "engine/transit/main", []byte("data"), nil)
	kid, err = ExtractKeyID(v2ct)
	if err != nil {
		t.Fatalf("ExtractKeyID v2: %v", err)
	}
	if kid != "engine/transit/main" {
		t.Fatalf("key ID: got %q, want %q", kid, "engine/transit/main")
	}
}

func TestV2KeyIDTooLong(t *testing.T) {
	key := newTestKey(t)

	// Use printable characters so length is the only constraint exercised.
	longID := strings.Repeat("k", int(MaxKeyIDLen)+1)
	_, err := EncryptV2(key, longID, []byte("data"), nil)
	if !errors.Is(err, ErrKeyIDTooLong) {
		t.Fatalf("expected ErrKeyIDTooLong, got: %v", err)
	}

	// Boundary: a key ID of exactly MaxKeyIDLen is treated as the inclusive
	// limit (as the name implies) and must round-trip.
	maxID := strings.Repeat("k", int(MaxKeyIDLen))
	ct := encryptV2OrFatal(t, key, maxID, []byte("data"), nil)
	_, gotKeyID, err := DecryptV2(key, ct, nil)
	if err != nil {
		t.Fatalf("DecryptV2 with max-length key ID: %v", err)
	}
	if gotKeyID != maxID {
		t.Fatalf("max-length key ID: got len %d, want len %d", len(gotKeyID), len(maxID))
	}
}

func TestV2EmptyKeyID(t *testing.T) {
	key := newTestKey(t)
	plaintext := []byte("data with empty key id")

	ct, err := EncryptV2(key, "", plaintext, nil)
	if err != nil {
		t.Fatalf("EncryptV2 empty key ID: %v", err)
	}

	pt, keyID, err := DecryptV2(key, ct, nil)
	if err != nil {
		t.Fatalf("DecryptV2 empty key ID: %v", err)
	}
	if keyID != "" {
		t.Fatalf("expected empty key ID, got %q", keyID)
	}
	if !bytes.Equal(plaintext, pt) {
		t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
	}
}