package sshca import ( "context" "crypto/ed25519" "crypto/rand" "encoding/binary" "errors" "strings" "sync" "testing" "golang.org/x/crypto/ssh" "git.wntrmute.dev/kyle/metacrypt/internal/barrier" "git.wntrmute.dev/kyle/metacrypt/internal/engine" ) // memBarrier is an in-memory barrier for testing. type memBarrier struct { data map[string][]byte mu sync.RWMutex } func newMemBarrier() *memBarrier { return &memBarrier{data: make(map[string][]byte)} } func (m *memBarrier) Unseal(_ []byte) error { return nil } func (m *memBarrier) Seal() error { return nil } func (m *memBarrier) IsSealed() bool { return false } func (m *memBarrier) Get(_ context.Context, path string) ([]byte, error) { m.mu.RLock() defer m.mu.RUnlock() v, ok := m.data[path] if !ok { return nil, barrier.ErrNotFound } cp := make([]byte, len(v)) copy(cp, v) return cp, nil } func (m *memBarrier) Put(_ context.Context, path string, value []byte) error { m.mu.Lock() defer m.mu.Unlock() cp := make([]byte, len(value)) copy(cp, value) m.data[path] = cp return nil } func (m *memBarrier) Delete(_ context.Context, path string) error { m.mu.Lock() defer m.mu.Unlock() delete(m.data, path) return nil } func (m *memBarrier) List(_ context.Context, prefix string) ([]string, error) { m.mu.RLock() defer m.mu.RUnlock() var paths []string for k := range m.data { if strings.HasPrefix(k, prefix) { paths = append(paths, strings.TrimPrefix(k, prefix)) } } return paths, nil } func adminCaller() *engine.CallerInfo { return &engine.CallerInfo{Username: "admin", Roles: []string{"admin"}, IsAdmin: true} } func userCaller() *engine.CallerInfo { return &engine.CallerInfo{Username: "user", Roles: []string{"user"}, IsAdmin: false} } func guestCaller() *engine.CallerInfo { return &engine.CallerInfo{Username: "guest", Roles: []string{"guest"}, IsAdmin: false} } func setupEngine(t *testing.T) (*SSHCAEngine, *memBarrier) { t.Helper() b := newMemBarrier() eng := NewSSHCAEngine().(*SSHCAEngine) //nolint:errcheck ctx := context.Background() config := map[string]interface{}{ "key_algorithm": "ed25519", "max_ttl": "87600h", "default_ttl": "24h", } if err := eng.Initialize(ctx, b, "engine/sshca/test/", config); err != nil { t.Fatalf("Initialize: %v", err) } return eng, b } func generateTestPubKey(t *testing.T) string { t.Helper() pub, _, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatalf("generate test key: %v", err) } sshPub, err := ssh.NewPublicKey(pub) if err != nil { t.Fatalf("create ssh public key: %v", err) } return string(ssh.MarshalAuthorizedKey(sshPub)) } func TestInitializeGeneratesCAKey(t *testing.T) { eng, _ := setupEngine(t) if eng.caKey == nil { t.Fatal("CA key is nil") } if eng.caSigner == nil { t.Fatal("CA signer is nil") } if eng.config == nil { t.Fatal("config is nil") } if eng.config.KeyAlgorithm != "ed25519" { t.Errorf("key algorithm: got %q, want %q", eng.config.KeyAlgorithm, "ed25519") } } func TestUnsealSealLifecycle(t *testing.T) { eng, b := setupEngine(t) mountPath := "engine/sshca/test/" // Seal and verify state is cleared. if err := eng.Seal(); err != nil { t.Fatalf("Seal: %v", err) } if eng.caKey != nil { t.Error("caKey should be nil after seal") } if eng.caSigner != nil { t.Error("caSigner should be nil after seal") } if eng.config != nil { t.Error("config should be nil after seal") } // Unseal and verify state is restored. ctx := context.Background() if err := eng.Unseal(ctx, b, mountPath); err != nil { t.Fatalf("Unseal: %v", err) } if eng.caKey == nil { t.Error("caKey should be non-nil after unseal") } if eng.caSigner == nil { t.Error("caSigner should be non-nil after unseal") } if eng.config == nil { t.Error("config should be non-nil after unseal") } } func TestSignHost(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": "web.example.com", }, }) if err != nil { t.Fatalf("sign-host: %v", err) } if resp.Data["cert_type"] != "host" { t.Errorf("cert_type: got %v, want %q", resp.Data["cert_type"], "host") } if resp.Data["serial"] == nil || resp.Data["serial"] == "" { t.Error("serial should not be empty") } if resp.Data["cert_data"] == nil { t.Error("cert_data should not be nil") } // Verify the certificate is parseable. certData := resp.Data["cert_data"].(string) //nolint:errcheck sshPubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certData)) if err != nil { t.Fatalf("parse cert: %v", err) } cert, ok := sshPubKey.(*ssh.Certificate) if !ok { t.Fatal("parsed key is not a certificate") } if cert.CertType != ssh.HostCert { t.Errorf("cert type: got %d, want %d", cert.CertType, ssh.HostCert) } if len(cert.ValidPrincipals) != 1 || cert.ValidPrincipals[0] != "web.example.com" { t.Errorf("principals: got %v", cert.ValidPrincipals) } } func TestSignHostTTLEnforcement(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) // Should fail: TTL exceeds max. _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": "web.example.com", "ttl": "999999h", }, }) if err == nil { t.Fatal("expected error for TTL exceeding max") } if !strings.Contains(err.Error(), "exceeds maximum") { t.Errorf("expected 'exceeds maximum' error, got: %v", err) } } func TestSignUser(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) // Default: signs for own username. resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-user", CallerInfo: userCaller(), Data: map[string]interface{}{ "public_key": pubKey, }, }) if err != nil { t.Fatalf("sign-user: %v", err) } if resp.Data["cert_type"] != "user" { t.Errorf("cert_type: got %v, want %q", resp.Data["cert_type"], "user") } // Verify principals. principals := resp.Data["principals"].([]interface{}) //nolint:errcheck if len(principals) != 1 || principals[0] != "user" { t.Errorf("principals: got %v, want [user]", principals) } // Verify extensions include permit-pty. certData := resp.Data["cert_data"].(string) //nolint:errcheck sshPubKey, _, _, _, _ := ssh.ParseAuthorizedKey([]byte(certData)) cert := sshPubKey.(*ssh.Certificate) //nolint:errcheck if _, ok := cert.Permissions.Extensions["permit-pty"]; !ok { t.Error("expected permit-pty extension") } } func TestSignUserOwnPrincipalDefault(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) // Non-admin cannot sign for another principal without policy. _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-user", CallerInfo: userCaller(), Data: map[string]interface{}{ "public_key": pubKey, "principals": []interface{}{"someone-else"}, }, }) if err == nil { t.Fatal("expected error for non-admin signing for another principal") } if !strings.Contains(err.Error(), "forbidden") { t.Errorf("expected forbidden error, got: %v", err) } // Admin can sign for any principal. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-user", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "principals": []interface{}{"someone-else"}, }, }) if err != nil { t.Fatalf("admin should sign for any principal: %v", err) } } func TestSignUserProfileMerging(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() // Create a profile. _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "create-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "restricted", "extensions": map[string]interface{}{ "permit-pty": "", "permit-port-forwarding": "", }, "critical_options": map[string]interface{}{ "force-command": "/bin/date", }, "allowed_principals": []interface{}{"user", "admin"}, }, }) if err != nil { t.Fatalf("create-profile: %v", err) } pubKey := generateTestPubKey(t) // Sign with profile. resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-user", CallerInfo: userCaller(), Data: map[string]interface{}{ "public_key": pubKey, "profile": "restricted", }, }) if err != nil { t.Fatalf("sign-user with profile: %v", err) } // Verify extensions are merged (profile wins on conflict). certData := resp.Data["cert_data"].(string) //nolint:errcheck sshPubKey, _, _, _, _ := ssh.ParseAuthorizedKey([]byte(certData)) cert := sshPubKey.(*ssh.Certificate) //nolint:errcheck if _, ok := cert.Permissions.Extensions["permit-port-forwarding"]; !ok { t.Error("expected permit-port-forwarding extension from profile") } if cert.Permissions.CriticalOptions["force-command"] != "/bin/date" { t.Error("expected force-command critical option from profile") } } func TestSignUserProfileEnforcesPrincipals(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "create-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "limited", "allowed_principals": []interface{}{"allowed-user"}, }, }) if err != nil { t.Fatalf("create-profile: %v", err) } pubKey := generateTestPubKey(t) // Should fail: principal not in profile's allowed list. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-user", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "profile": "limited", "principals": []interface{}{"not-allowed"}, }, }) if err == nil { t.Fatal("expected error for principal not in allowed list") } if !strings.Contains(err.Error(), "not allowed by profile") { t.Errorf("expected 'not allowed' error, got: %v", err) } } func TestProfileCRUD(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() // Create. _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "create-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "myprofile", "extensions": map[string]interface{}{ "permit-pty": "", }, }, }) if err != nil { t.Fatalf("create-profile: %v", err) } // Duplicate create should fail. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "create-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "myprofile", }, }) if !errors.Is(err, ErrProfileExists) { t.Errorf("expected ErrProfileExists, got: %v", err) } // Get. resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "get-profile", CallerInfo: userCaller(), Data: map[string]interface{}{ "name": "myprofile", }, }) if err != nil { t.Fatalf("get-profile: %v", err) } if resp.Data["name"] != "myprofile" { t.Errorf("profile name: got %v", resp.Data["name"]) } // List. resp, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "list-profiles", CallerInfo: userCaller(), }) if err != nil { t.Fatalf("list-profiles: %v", err) } profiles := resp.Data["profiles"].([]interface{}) //nolint:errcheck if len(profiles) != 1 { t.Errorf("expected 1 profile, got %d", len(profiles)) } // Update. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "update-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "myprofile", "max_ttl": "48h", }, }) if err != nil { t.Fatalf("update-profile: %v", err) } // Verify update. resp, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "get-profile", CallerInfo: userCaller(), Data: map[string]interface{}{ "name": "myprofile", }, }) if err != nil { t.Fatalf("get-profile after update: %v", err) } if resp.Data["max_ttl"] != "48h" { t.Errorf("max_ttl: got %v, want %q", resp.Data["max_ttl"], "48h") } // Delete. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "delete-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "myprofile", }, }) if err != nil { t.Fatalf("delete-profile: %v", err) } // Verify deleted. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "get-profile", CallerInfo: userCaller(), Data: map[string]interface{}{ "name": "myprofile", }, }) if !errors.Is(err, ErrProfileNotFound) { t.Errorf("expected ErrProfileNotFound, got: %v", err) } } func TestCertListGetRevokeDelete(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) // Sign two certs. var serials []string for _, hostname := range []string{"a.example.com", "b.example.com"} { resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": hostname, }, }) if err != nil { t.Fatalf("sign-host %s: %v", hostname, err) } serials = append(serials, resp.Data["serial"].(string)) //nolint:errcheck } // List certs. resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "list-certs", CallerInfo: userCaller(), }) if err != nil { t.Fatalf("list-certs: %v", err) } certs := resp.Data["certs"].([]interface{}) //nolint:errcheck if len(certs) != 2 { t.Errorf("expected 2 certs, got %d", len(certs)) } // Get cert. resp, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "get-cert", CallerInfo: userCaller(), Data: map[string]interface{}{ "serial": serials[0], }, }) if err != nil { t.Fatalf("get-cert: %v", err) } if resp.Data["serial"] != serials[0] { t.Errorf("serial: got %v, want %v", resp.Data["serial"], serials[0]) } // Revoke cert. resp, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "revoke-cert", CallerInfo: adminCaller(), Data: map[string]interface{}{ "serial": serials[0], }, }) if err != nil { t.Fatalf("revoke-cert: %v", err) } if resp.Data["revoked_at"] == nil { t.Error("revoked_at should not be nil") } // Verify revoked. resp, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "get-cert", CallerInfo: userCaller(), Data: map[string]interface{}{ "serial": serials[0], }, }) if err != nil { t.Fatalf("get-cert after revoke: %v", err) } if resp.Data["revoked"] != true { t.Error("cert should be revoked") } // Delete cert. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "delete-cert", CallerInfo: adminCaller(), Data: map[string]interface{}{ "serial": serials[1], }, }) if err != nil { t.Fatalf("delete-cert: %v", err) } // Verify deleted. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "get-cert", CallerInfo: userCaller(), Data: map[string]interface{}{ "serial": serials[1], }, }) if !errors.Is(err, ErrCertNotFound) { t.Errorf("expected ErrCertNotFound, got: %v", err) } } func TestKRLContainsRevokedSerials(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) // Sign a cert. resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": "revoke-me.example.com", }, }) if err != nil { t.Fatalf("sign-host: %v", err) } serial := resp.Data["serial"].(string) //nolint:errcheck // Revoke it. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "revoke-cert", CallerInfo: adminCaller(), Data: map[string]interface{}{ "serial": serial, }, }) if err != nil { t.Fatalf("revoke-cert: %v", err) } // Get KRL. krlResp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "get-krl", }) if err != nil { t.Fatalf("get-krl: %v", err) } krlData := []byte(krlResp.Data["krl"].(string)) //nolint:errcheck if len(krlData) < 12 { t.Fatal("KRL data too short") } // Verify magic. magic := string(krlData[:12]) if magic != "OPENSSH_KRL\x00" { t.Errorf("KRL magic: got %q", magic) } // KRL should contain a certificate section since there are revoked serials. // The section starts after the header (12 + 4 + 8 + 8 + 8 + 4 + 4 = 48 bytes). if len(krlData) <= 48 { t.Error("KRL should contain certificate section with revoked serials") } // Verify the section type is 0x01 (KRL_SECTION_CERTIFICATES). if krlData[48] != 0x01 { t.Errorf("expected section type 0x01, got 0x%02x", krlData[48]) } // Verify the KRL contains the revoked serial somewhere in the data. // Parse the serial from the response. var serialUint uint64 for i := 0; i < len(serial); i++ { serialUint = serialUint*10 + uint64(serial[i]-'0') } var serialBytes [8]byte binary.BigEndian.PutUint64(serialBytes[:], serialUint) found := false for i := 48; i <= len(krlData)-8; i++ { if krlData[i] == serialBytes[0] && krlData[i+1] == serialBytes[1] && krlData[i+2] == serialBytes[2] && krlData[i+3] == serialBytes[3] && krlData[i+4] == serialBytes[4] && krlData[i+5] == serialBytes[5] && krlData[i+6] == serialBytes[6] && krlData[i+7] == serialBytes[7] { found = true break } } if !found { t.Error("KRL should contain the revoked serial") } } func TestAuthEnforcement(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() pubKey := generateTestPubKey(t) // Guest rejected for sign-host. _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: guestCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": "test.example.com", }, }) if !errors.Is(err, ErrForbidden) { t.Errorf("expected ErrForbidden for guest sign-host, got: %v", err) } // Guest rejected for sign-user. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-user", CallerInfo: guestCaller(), Data: map[string]interface{}{ "public_key": pubKey, }, }) if !errors.Is(err, ErrForbidden) { t.Errorf("expected ErrForbidden for guest sign-user, got: %v", err) } // Nil caller rejected. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", Data: map[string]interface{}{ "public_key": pubKey, "hostname": "test.example.com", }, }) if !errors.Is(err, ErrUnauthorized) { t.Errorf("expected ErrUnauthorized for nil caller, got: %v", err) } // Admin-only operations reject non-admin. for _, op := range []string{"create-profile", "update-profile", "delete-profile", "revoke-cert", "delete-cert"} { _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: op, CallerInfo: userCaller(), Data: map[string]interface{}{ "name": "test", "serial": "123", }, }) if !errors.Is(err, ErrForbidden) { t.Errorf("expected ErrForbidden for user %s, got: %v", op, err) } } // User can read profiles and certs. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "list-profiles", CallerInfo: userCaller(), }) if err != nil { t.Errorf("user should list-profiles: %v", err) } _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "list-certs", CallerInfo: userCaller(), }) if err != nil { t.Errorf("user should list-certs: %v", err) } // Guest cannot list profiles or certs. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "list-profiles", CallerInfo: guestCaller(), }) if !errors.Is(err, ErrForbidden) { t.Errorf("expected ErrForbidden for guest list-profiles, got: %v", err) } _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "list-certs", CallerInfo: guestCaller(), }) if !errors.Is(err, ErrForbidden) { t.Errorf("expected ErrForbidden for guest list-certs, got: %v", err) } } func TestGetCAPubkey(t *testing.T) { eng, _ := setupEngine(t) ctx := context.Background() resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "get-ca-pubkey", }) if err != nil { t.Fatalf("get-ca-pubkey: %v", err) } pubKeyStr := resp.Data["public_key"].(string) //nolint:errcheck if pubKeyStr == "" { t.Error("public_key should not be empty") } // Should be parseable as SSH public key. _, _, _, _, err = ssh.ParseAuthorizedKey([]byte(pubKeyStr)) if err != nil { t.Errorf("parse public key: %v", err) } } func TestUnsealRestoresState(t *testing.T) { eng, b := setupEngine(t) ctx := context.Background() mountPath := "engine/sshca/test/" pubKey := generateTestPubKey(t) // Sign a cert and create a profile. _, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": "persist.example.com", }, }) if err != nil { t.Fatalf("sign-host: %v", err) } _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "create-profile", CallerInfo: adminCaller(), Data: map[string]interface{}{ "name": "persist-profile", }, }) if err != nil { t.Fatalf("create-profile: %v", err) } // Seal. _ = eng.Seal() // Unseal. if err := eng.Unseal(ctx, b, mountPath); err != nil { t.Fatalf("Unseal: %v", err) } // Verify we can still list certs. resp, err := eng.HandleRequest(ctx, &engine.Request{ Operation: "list-certs", CallerInfo: userCaller(), }) if err != nil { t.Fatalf("list-certs after unseal: %v", err) } certs := resp.Data["certs"].([]interface{}) //nolint:errcheck if len(certs) != 1 { t.Errorf("expected 1 cert after unseal, got %d", len(certs)) } // Verify we can still list profiles. resp, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "list-profiles", CallerInfo: userCaller(), }) if err != nil { t.Fatalf("list-profiles after unseal: %v", err) } profiles := resp.Data["profiles"].([]interface{}) //nolint:errcheck if len(profiles) != 1 { t.Errorf("expected 1 profile after unseal, got %d", len(profiles)) } // Verify we can still sign. _, err = eng.HandleRequest(ctx, &engine.Request{ Operation: "sign-host", CallerInfo: adminCaller(), Data: map[string]interface{}{ "public_key": pubKey, "hostname": "after-unseal.example.com", }, }) if err != nil { t.Fatalf("sign-host after unseal: %v", err) } } func TestEngineType(t *testing.T) { eng := NewSSHCAEngine() if eng.Type() != engine.EngineTypeSSHCA { t.Errorf("Type: got %v, want %v", eng.Type(), engine.EngineTypeSSHCA) } }