package vault

import (
	"crypto/ed25519"
	"crypto/rand"
	"errors"
	"sync"
	"testing"
)

// generateTestKeys returns a random 32-byte master key together with a freshly
// generated Ed25519 private/public key pair. Any failure of the random source
// aborts the test via t.Fatalf.
//
// Because it may call t.Fatalf, it must only be called from the goroutine that
// is running the test function, never from goroutines the test spawns.
func generateTestKeys(t *testing.T) ([]byte, ed25519.PrivateKey, ed25519.PublicKey) {
	t.Helper()

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	mk := make([]byte, 32)
	if _, err := rand.Read(mk); err != nil {
		t.Fatalf("generate master key: %v", err)
	}

	return mk, priv, pub
}

// TestNewSealed checks that a vault created with NewSealed reports itself as
// sealed and that every key accessor returns ErrSealed.
func TestNewSealed(t *testing.T) {
	v := NewSealed()
	if !v.IsSealed() {
		t.Fatal("NewSealed() should be sealed")
	}

	if _, err := v.MasterKey(); !errors.Is(err, ErrSealed) {
		t.Fatalf("MasterKey() error = %v, want ErrSealed", err)
	}
	if _, err := v.PrivKey(); !errors.Is(err, ErrSealed) {
		t.Fatalf("PrivKey() error = %v, want ErrSealed", err)
	}
	if _, err := v.PubKey(); !errors.Is(err, ErrSealed) {
		t.Fatalf("PubKey() error = %v, want ErrSealed", err)
	}
}

// TestNewUnsealed checks that a vault created with NewUnsealed is not sealed
// and hands back a 32-byte master key without error.
func TestNewUnsealed(t *testing.T) {
	mk, priv, pub := generateTestKeys(t)

	v := NewUnsealed(mk, priv, pub)
	if v.IsSealed() {
		t.Fatal("NewUnsealed() should not be sealed")
	}

	gotMK, err := v.MasterKey()
	if err != nil {
		t.Fatalf("MasterKey() error = %v", err)
	}
	if len(gotMK) != 32 {
		t.Fatalf("MasterKey() len = %d, want 32", len(gotMK))
	}
}

// TestUnsealFromSealed checks that Unseal on a sealed vault succeeds, flips the
// sealed state, and makes the supplied private key available through PrivKey.
func TestUnsealFromSealed(t *testing.T) {
	mk, priv, pub := generateTestKeys(t)

	v := NewSealed()
	if err := v.Unseal(mk, priv, pub); err != nil {
		t.Fatalf("Unseal() error = %v", err)
	}
	if v.IsSealed() {
		t.Fatal("should be unsealed after Unseal()")
	}

	gotPriv, err := v.PrivKey()
	if err != nil {
		t.Fatalf("PrivKey() error = %v", err)
	}
	if !priv.Equal(gotPriv) {
		t.Fatal("PrivKey() mismatch")
	}
}

// TestUnsealAlreadyUnsealed checks that calling Unseal on a vault that is
// already unsealed returns a non-nil error.
func TestUnsealAlreadyUnsealed(t *testing.T) {
	mk, priv, pub := generateTestKeys(t)

	v := NewUnsealed(mk, priv, pub)
	if err := v.Unseal(mk, priv, pub); err == nil {
		t.Fatal("Unseal() on unsealed vault should return error")
	}
}

// TestSealZeroesKeys checks that Seal marks the vault as sealed and overwrites
// the master key and private key bytes with zeros. The check is made through
// the slices originally passed to NewUnsealed, so the test expects the vault to
// zero that same backing memory rather than a private copy.
func TestSealZeroesKeys(t *testing.T) {
	mk, priv, pub := generateTestKeys(t)

	// Keep references to the backing arrays so we can verify zeroing.
	mkRef := mk
	privRef := priv

	v := NewUnsealed(mk, priv, pub)
	v.Seal()

	if !v.IsSealed() {
		t.Fatal("should be sealed after Seal()")
	}

	// Verify the original backing arrays were zeroed.
	for i, b := range mkRef {
		if b != 0 {
			t.Fatalf("masterKey[%d] = %d, want 0", i, b)
		}
	}
	for i, b := range privRef {
		if b != 0 {
			t.Fatalf("privKey[%d] = %d, want 0", i, b)
		}
	}
}

// TestSealUnsealCycle checks that a vault can be sealed and then unsealed again
// with a different key set, after which PubKey returns the new public key.
func TestSealUnsealCycle(t *testing.T) {
	mk, priv, pub := generateTestKeys(t)
	v := NewUnsealed(mk, priv, pub)
	v.Seal()

	mk2, priv2, pub2 := generateTestKeys(t)
	if err := v.Unseal(mk2, priv2, pub2); err != nil {
		t.Fatalf("Unseal() after Seal() error = %v", err)
	}

	gotPub, err := v.PubKey()
	if err != nil {
		t.Fatalf("PubKey() error = %v", err)
	}
	if !pub2.Equal(gotPub) {
		t.Fatal("PubKey() mismatch after re-unseal")
	}
}

// TestConcurrentAccess runs the vault accessors concurrently with seal/unseal
// cycles. It makes no assertions of its own: a problem would surface as a
// panic or, when the tests are run with -race, as a reported data race.
func TestConcurrentAccess(t *testing.T) {
	const (
		readers = 50
		cyclers = 10
	)

	mk, priv, pub := generateTestKeys(t)
	v := NewUnsealed(mk, priv, pub)

	// generateTestKeys can call t.Fatalf, which must run on the test goroutine,
	// so every key set is created here before any goroutine is started. Each
	// cycler gets its own set so no key memory is shared between them.
	type keySet struct {
		mk   []byte
		priv ed25519.PrivateKey
		pub  ed25519.PublicKey
	}
	sets := make([]keySet, cyclers)
	for i := range sets {
		sets[i].mk, sets[i].priv, sets[i].pub = generateTestKeys(t)
	}

	var wg sync.WaitGroup

	// Concurrent readers. Results are discarded on purpose: depending on
	// scheduling the vault may be sealed or unsealed when each call runs.
	for range readers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = v.IsSealed()
			_, _ = v.MasterKey()
			_, _ = v.PrivKey()
			_, _ = v.PubKey()
		}()
	}

	// Concurrent seal/unseal cycles.
	for _, ks := range sets {
		wg.Add(1)
		go func(ks keySet) {
			defer wg.Done()
			v.Seal()
			// The error is ignored on purpose: another goroutine may have
			// unsealed the vault between this goroutine's Seal and Unseal.
			_ = v.Unseal(ks.mk, ks.priv, ks.pub)
		}(ks)
	}

	wg.Wait()
}