package agent

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/json"
	"encoding/pem"
	"log/slog"
	"math/big"
	"net/http"
	"net/http/httptest"
	"os"
	"path/filepath"
	"slices"
	"sync/atomic"
	"testing"
	"time"

	"git.wntrmute.dev/mc/mcp/internal/config"
	"git.wntrmute.dev/mc/mcp/internal/registry"
)

// TestNilCertProvisionerIsNoop checks that EnsureCert on a nil
// *CertProvisioner returns nil instead of panicking.
func TestNilCertProvisionerIsNoop(t *testing.T) {
	var p *CertProvisioner
	if err := p.EnsureCert(context.Background(), "svc", []string{"svc.example.com"}); err != nil {
		t.Fatalf("EnsureCert on nil: %v", err)
	}
}

// TestNewCertProvisionerDisabledWhenUnconfigured checks that an empty
// MetacryptConfig yields a nil provisioner and no error.
func TestNewCertProvisionerDisabledWhenUnconfigured(t *testing.T) {
	// Use a per-test directory rather than a shared, host-specific /tmp.
	p, err := NewCertProvisioner(config.MetacryptConfig{}, t.TempDir(), slog.Default())
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if p != nil {
		t.Fatal("expected nil provisioner for empty config")
	}
}

// TestEnsureCertSkipsValidCert checks that a cert far from expiry is left in
// place and that Metacrypt is never contacted.
func TestEnsureCertSkipsValidCert(t *testing.T) {
	certDir := t.TempDir()
	certPath := filepath.Join(certDir, "svc.pem")
	keyPath := filepath.Join(certDir, "svc.key")

	// Generate a cert that expires in 90 days.
	writeSelfSignedCert(t, certPath, keyPath, "svc.example.com", 90*24*time.Hour)
	before, err := os.ReadFile(certPath)
	if err != nil {
		t.Fatalf("read cert: %v", err)
	}

	// Count any request the provisioner makes. A recording server gives a
	// deterministic failure signal; an unreachable hostname can stall on
	// DNS resolution instead.
	var calls atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls.Add(1)
		http.Error(w, "unexpected issuance request", http.StatusInternalServerError)
	}))
	defer srv.Close()

	p := &CertProvisioner{
		serverURL:  srv.URL,
		token:      "test-token",
		mount:      "pki",
		issuer:     "infra",
		certDir:    certDir,
		httpClient: srv.Client(),
		logger:     slog.Default(),
	}

	if err := p.EnsureCert(context.Background(), "svc", []string{"svc.example.com"}); err != nil {
		t.Fatalf("EnsureCert: %v", err)
	}
	if n := calls.Load(); n != 0 {
		t.Fatalf("EnsureCert contacted Metacrypt %d time(s) for a still-valid cert", n)
	}

	after, err := os.ReadFile(certPath)
	if err != nil {
		t.Fatalf("read cert: %v", err)
	}
	if string(after) != string(before) {
		t.Fatal("still-valid cert was rewritten")
	}
}

// TestEnsureCertReissuesExpiring checks that a cert inside the renewal window
// is replaced by the cert and key returned from Metacrypt.
func TestEnsureCertReissuesExpiring(t *testing.T) {
	certDir := t.TempDir()
	certPath := filepath.Join(certDir, "svc.pem")
	keyPath := filepath.Join(certDir, "svc.key")

	// Generate a cert that expires in 10 days (within 30-day renewal window).
	writeSelfSignedCert(t, certPath, keyPath, "svc.example.com", 10*24*time.Hour)

	// Mock Metacrypt API.
	newCert, newKey := generateCertPEM(t, "svc.example.com", 90*24*time.Hour)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		resp := map[string]string{
			"chain_pem":  newCert,
			"key_pem":    newKey,
			"serial":     "abc123",
			"expires_at": time.Now().Add(90 * 24 * time.Hour).Format(time.RFC3339),
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer srv.Close()

	p := &CertProvisioner{
		serverURL:  srv.URL,
		token:      "test-token",
		mount:      "pki",
		issuer:     "infra",
		certDir:    certDir,
		httpClient: srv.Client(),
		logger:     slog.Default(),
	}

	if err := p.EnsureCert(context.Background(), "svc", []string{"svc.example.com"}); err != nil {
		t.Fatalf("EnsureCert: %v", err)
	}

	// Verify new cert was written.
	got, err := os.ReadFile(certPath)
	if err != nil {
		t.Fatalf("read cert: %v", err)
	}
	if string(got) != newCert {
		t.Fatal("cert file was not updated with new cert")
	}

	// The old key belongs to the old cert, so it must be replaced too.
	gotKey, err := os.ReadFile(keyPath)
	if err != nil {
		t.Fatalf("read key: %v", err)
	}
	if string(gotKey) != newKey {
		t.Fatal("key file was not updated with new key")
	}
}

// TestIssueCertWritesFiles checks the issuance request shape, the bearer
// token, and the contents and permissions of the written cert and key files.
func TestIssueCertWritesFiles(t *testing.T) {
	certDir := t.TempDir()

	// Mock Metacrypt API.
	certPEM, keyPEM := generateCertPEM(t, "svc.example.com", 90*24*time.Hour)
	var gotAuth string
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotAuth = r.Header.Get("Authorization")

		var req map[string]any
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}

		// Verify request structure. Errorf (not Fatalf) because this runs
		// on the server's goroutine, not the test goroutine.
		if req["mount"] != "pki" || req["operation"] != "issue" {
			t.Errorf("unexpected request: %v", req)
		}

		resp := map[string]string{
			"chain_pem":  certPEM,
			"key_pem":    keyPEM,
			"serial":     "deadbeef",
			"expires_at": time.Now().Add(90 * 24 * time.Hour).Format(time.RFC3339),
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(resp)
	}))
	defer srv.Close()

	p := &CertProvisioner{
		serverURL:  srv.URL,
		token:      "my-service-token",
		mount:      "pki",
		issuer:     "infra",
		certDir:    certDir,
		httpClient: srv.Client(),
		logger:     slog.Default(),
	}

	if err := p.EnsureCert(context.Background(), "svc", []string{"svc.example.com"}); err != nil {
		t.Fatalf("EnsureCert: %v", err)
	}

	// Verify auth header.
	if gotAuth != "Bearer my-service-token" {
		t.Fatalf("auth header: got %q, want %q", gotAuth, "Bearer my-service-token")
	}

	// Verify cert file.
	certData, err := os.ReadFile(filepath.Join(certDir, "svc.pem"))
	if err != nil {
		t.Fatalf("read cert: %v", err)
	}
	if string(certData) != certPEM {
		t.Fatal("cert content mismatch")
	}

	// Verify key file.
	keyData, err := os.ReadFile(filepath.Join(certDir, "svc.key"))
	if err != nil {
		t.Fatalf("read key: %v", err)
	}
	if string(keyData) != keyPEM {
		t.Fatal("key content mismatch")
	}

	// Verify key file permissions.
	info, err := os.Stat(filepath.Join(certDir, "svc.key"))
	if err != nil {
		t.Fatalf("stat key: %v", err)
	}
	if perm := info.Mode().Perm(); perm != 0600 {
		t.Fatalf("key permissions: got %o, want 0600", perm)
	}
}

// TestIssueCertAPIError checks that an error response from Metacrypt (here a
// 503 "sealed") is surfaced by EnsureCert.
func TestIssueCertAPIError(t *testing.T) {
	certDir := t.TempDir()

	var calls atomic.Int32
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		calls.Add(1)
		http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
	}))
	defer srv.Close()

	p := &CertProvisioner{
		serverURL:  srv.URL,
		token:      "test-token",
		mount:      "pki",
		issuer:     "infra",
		certDir:    certDir,
		httpClient: srv.Client(),
		logger:     slog.Default(),
	}

	err := p.EnsureCert(context.Background(), "svc", []string{"svc.example.com"})
	if err == nil {
		t.Fatal("expected error for sealed metacrypt")
	}
	// Without this check, an unrelated early failure would also satisfy the
	// test.
	if calls.Load() == 0 {
		t.Fatalf("EnsureCert failed before contacting Metacrypt: %v", err)
	}
}

// TestCertTimeRemaining covers missing, valid and already-expired cert files.
func TestCertTimeRemaining(t *testing.T) {
	t.Run("missing file", func(t *testing.T) {
		if _, ok := certTimeRemaining("/nonexistent/cert.pem"); ok {
			t.Fatal("expected false for missing file")
		}
	})

	t.Run("valid cert", func(t *testing.T) {
		certDir := t.TempDir()
		path := filepath.Join(certDir, "test.pem")
		writeSelfSignedCert(t, path, filepath.Join(certDir, "test.key"), "test.example.com", 90*24*time.Hour)

		remaining, ok := certTimeRemaining(path)
		if !ok {
			t.Fatal("expected true for valid cert")
		}
		// Should be close to 90 days.
		if remaining < 89*24*time.Hour || remaining > 91*24*time.Hour {
			t.Fatalf("remaining: got %v, want ~90 days", remaining)
		}
	})

	t.Run("expired cert", func(t *testing.T) {
		certDir := t.TempDir()
		path := filepath.Join(certDir, "expired.pem")
		// Write a cert that's already expired (valid from -2h to -1h).
		writeExpiredCert(t, path, filepath.Join(certDir, "expired.key"), "expired.example.com")

		remaining, ok := certTimeRemaining(path)
		if !ok {
			t.Fatal("expected true for expired cert")
		}
		if remaining > 0 {
			t.Fatalf("remaining: got %v, want <= 0", remaining)
		}
	})
}

// TestHasL7Routes checks that hasL7Routes reports whether any route uses
// mode "l7".
func TestHasL7Routes(t *testing.T) {
	tests := []struct {
		name   string
		routes []registry.Route
		want   bool
	}{
		{"nil", nil, false},
		{"empty", []registry.Route{}, false},
		{"l4 only", []registry.Route{{Mode: "l4"}}, false},
		{"l7 only", []registry.Route{{Mode: "l7"}}, true},
		{"mixed", []registry.Route{{Mode: "l4"}, {Mode: "l7"}}, true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := hasL7Routes(tt.routes); got != tt.want {
				t.Fatalf("hasL7Routes = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestL7Hostnames checks that l7Hostnames ignores l4 routes, maps an empty
// hostname to the service's default name, and de-duplicates in order.
func TestL7Hostnames(t *testing.T) {
	routes := []registry.Route{
		{Mode: "l7", Hostname: ""},
		{Mode: "l4", Hostname: "ignored.example.com"},
		{Mode: "l7", Hostname: "custom.example.com"},
		{Mode: "l7", Hostname: ""}, // duplicate default
	}

	got := l7Hostnames("myservice", routes)
	want := []string{"myservice.svc.mcp.metacircular.net", "custom.example.com"}
	if !slices.Equal(got, want) {
		t.Fatalf("l7Hostnames = %v, want %v", got, want)
	}
}

// TestAtomicWrite checks that atomicWrite writes the content and leaves no
// ".tmp" sibling behind.
func TestAtomicWrite(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "test.txt")

	if err := atomicWrite(path, []byte("hello"), 0644); err != nil {
		t.Fatalf("atomicWrite: %v", err)
	}

	data, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read: %v", err)
	}
	if string(data) != "hello" {
		t.Fatalf("got %q, want %q", string(data), "hello")
	}

	// Verify no .tmp file left behind.
	if _, err := os.Stat(path + ".tmp"); !os.IsNotExist(err) {
		t.Fatal("temp file should not exist after atomic write")
	}
}

// --- test helpers ---

// writeSelfSignedCert generates a self-signed cert/key and writes them to disk.
func writeSelfSignedCert(t *testing.T, certPath, keyPath, hostname string, validity time.Duration) { t.Helper() certPEM, keyPEM := generateCertPEM(t, hostname, validity) if err := os.WriteFile(certPath, []byte(certPEM), 0644); err != nil { t.Fatalf("write cert: %v", err) } if err := os.WriteFile(keyPath, []byte(keyPEM), 0600); err != nil { t.Fatalf("write key: %v", err) } } // writeExpiredCert generates a cert that is already expired. func writeExpiredCert(t *testing.T, certPath, keyPath, hostname string) { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("generate key: %v", err) } tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: hostname}, DNSNames: []string{hostname}, NotBefore: time.Now().Add(-2 * time.Hour), NotAfter: time.Now().Add(-1 * time.Hour), } der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) if err != nil { t.Fatalf("create cert: %v", err) } certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) keyDER, err := x509.MarshalECPrivateKey(key) if err != nil { t.Fatalf("marshal key: %v", err) } keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) if err := os.WriteFile(certPath, certPEM, 0644); err != nil { t.Fatalf("write cert: %v", err) } if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { t.Fatalf("write key: %v", err) } } // generateCertPEM generates a self-signed cert and returns PEM strings. 
//
// Each call uses a fresh P-256 key and a random positive serial number, so
// two certificates for the same hostname are distinguishable. The validity
// window starts one hour before the call and ends validity after it; a
// negative validity therefore yields an already-expired certificate. The
// key is returned as a SEC 1 "EC PRIVATE KEY" PEM block. Failures abort the
// test via t.Fatalf.
func generateCertPEM(t *testing.T, hostname string, validity time.Duration) (certPEM, keyPEM string) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	// RFC 5280 requires serial numbers to be unique per issuer. These certs
	// are self-signed with Issuer == Subject == CN=hostname, so a fixed
	// serial would give two certs for the same hostname an identical
	// issuer+serial pair. Draw a random serial in [1, 2^127] instead
	// (positive, and well under the 20-octet limit).
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 127))
	if err != nil {
		t.Fatalf("generate serial: %v", err)
	}
	serial.Add(serial, big.NewInt(1))

	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{CommonName: hostname},
		DNSNames:     []string{hostname},
		NotBefore:    time.Now().Add(-1 * time.Hour),
		NotAfter:     time.Now().Add(validity),
	}

	// Self-signed: tmpl acts as both the certificate and its issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("create cert: %v", err)
	}
	certBlock := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})

	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("marshal key: %v", err)
	}
	keyBlock := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})

	return string(certBlock), string(keyBlock)
}