package webserver import ( "context" "fmt" "io/fs" "log/slog" "net/http" "net/http/httptest" "strings" "testing" "github.com/go-chi/chi/v5" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" webui "git.wntrmute.dev/kyle/metacrypt/web" ) // mockVault is a minimal vaultBackend implementation for tests. // All methods return zero values unless overridden via the function fields. type mockVault struct { statusFn func(ctx context.Context) (string, error) validateTokenFn func(ctx context.Context, token string) (*TokenInfo, error) listMountsFn func(ctx context.Context, token string) ([]MountInfo, error) getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error) issueCertFn func(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) } func (m *mockVault) Status(ctx context.Context) (string, error) { if m.statusFn != nil { return m.statusFn(ctx) } return "unsealed", nil } func (m *mockVault) Init(ctx context.Context, password string) error { return nil } func (m *mockVault) Unseal(ctx context.Context, password string) error { return nil } func (m *mockVault) Login(ctx context.Context, username, password, totpCode string) (string, error) { return "", nil } func (m *mockVault) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) { if m.validateTokenFn != nil { return m.validateTokenFn(ctx, token) } return &TokenInfo{Username: "testuser", IsAdmin: false}, nil } func (m *mockVault) ListMounts(ctx context.Context, token string) ([]MountInfo, error) { if m.listMountsFn != nil { return m.listMountsFn(ctx, token) } return []MountInfo{{Name: "pki", Type: "ca"}}, nil } func (m *mockVault) Mount(ctx context.Context, token, name, engineType string, config map[string]interface{}) error { return nil } func (m *mockVault) GetRootCert(ctx context.Context, mount string) ([]byte, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) GetIssuerCert(ctx context.Context, mount, issuer string) ([]byte, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) ImportRoot(ctx context.Context, token, mount, certPEM, keyPEM string) error { return nil } func (m *mockVault) CreateIssuer(ctx context.Context, token string, req CreateIssuerRequest) error { return nil } func (m *mockVault) ListIssuers(ctx context.Context, token, mount string) ([]string, error) { return nil, nil } func (m *mockVault) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) { if m.issueCertFn != nil { return m.issueCertFn(ctx, token, req) } return nil, fmt.Errorf("not implemented") } func (m *mockVault) SignCSR(ctx context.Context, token string, req SignCSRRequest) (*SignedCert, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) GetCert(ctx context.Context, token, mount, serial string) (*CertDetail, error) { if m.getCertFn != nil { return m.getCertFn(ctx, token, mount, serial) } return nil, fmt.Errorf("not implemented") } func (m *mockVault) ListCerts(ctx context.Context, token, mount string) ([]CertSummary, error) { return nil, nil } func (m *mockVault) RevokeCert(ctx context.Context, token, mount, serial string) error { return nil } func (m *mockVault) DeleteCert(ctx context.Context, token, mount, serial string) error { return nil } func (m *mockVault) ListPolicies(ctx context.Context, token string) ([]PolicyRule, error) { return nil, nil } func (m *mockVault) GetPolicy(ctx context.Context, token, id string) (*PolicyRule, error) { return nil, nil } func (m *mockVault) CreatePolicy(ctx context.Context, token string, rule PolicyRule) (*PolicyRule, error) { return nil, nil } func (m *mockVault) DeletePolicy(ctx context.Context, token, id string) error { return nil } // SSH CA stubs func (m *mockVault) GetSSHCAPublicKey(ctx context.Context, mount string) (*SSHCAPublicKey, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) SSHCASignHost(ctx context.Context, token, mount string, req SSHCASignRequest) (*SSHCASignResult, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) SSHCASignUser(ctx context.Context, token, mount string, req SSHCASignRequest) (*SSHCASignResult, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) ListSSHCAProfiles(ctx context.Context, token, mount string) ([]SSHCAProfileSummary, error) { return nil, nil } func (m *mockVault) GetSSHCAProfile(ctx context.Context, token, mount, name string) (*SSHCAProfile, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) CreateSSHCAProfile(ctx context.Context, token, mount string, req SSHCAProfileRequest) error { return nil } func (m *mockVault) UpdateSSHCAProfile(ctx context.Context, token, mount, name string, req SSHCAProfileRequest) error { return nil } func (m *mockVault) DeleteSSHCAProfile(ctx context.Context, token, mount, name string) error { return nil } func (m *mockVault) ListSSHCACerts(ctx context.Context, token, mount string) ([]SSHCACertSummary, error) { return nil, nil } func (m *mockVault) GetSSHCACert(ctx context.Context, token, mount, serial string) (*SSHCACertDetail, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) RevokeSSHCACert(ctx context.Context, token, mount, serial string) error { return nil } func (m *mockVault) DeleteSSHCACert(ctx context.Context, token, mount, serial string) error { return nil } func (m *mockVault) GetSSHCAKRL(ctx context.Context, mount string) ([]byte, error) { return nil, fmt.Errorf("not implemented") } // Transit stubs func (m *mockVault) ListTransitKeys(ctx context.Context, token, mount string) ([]TransitKeySummary, error) { return nil, nil } func (m *mockVault) GetTransitKey(ctx context.Context, token, mount, name string) (*TransitKeyDetail, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) CreateTransitKey(ctx context.Context, token, mount, name, keyType string) error { return nil } func (m *mockVault) DeleteTransitKey(ctx context.Context, token, mount, name string) error { return nil } func (m *mockVault) RotateTransitKey(ctx context.Context, token, mount, name string) error { return nil } func (m *mockVault) UpdateTransitKeyConfig(ctx context.Context, token, mount, name string, minDecryptVersion int, allowDeletion bool) error { return nil } func (m *mockVault) TrimTransitKey(ctx context.Context, token, mount, name string) (int, error) { return 0, nil } func (m *mockVault) TransitEncrypt(ctx context.Context, token, mount, key, plaintext, transitCtx string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) TransitDecrypt(ctx context.Context, token, mount, key, ciphertext, transitCtx string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) TransitRewrap(ctx context.Context, token, mount, key, ciphertext, transitCtx string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) TransitSign(ctx context.Context, token, mount, key, input string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) TransitVerify(ctx context.Context, token, mount, key, input, signature string) (bool, error) { return false, fmt.Errorf("not implemented") } func (m *mockVault) TransitHMAC(ctx context.Context, token, mount, key, input string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) GetTransitPublicKey(ctx context.Context, token, mount, name string) (string, error) { return "", fmt.Errorf("not implemented") } // User stubs func (m *mockVault) UserRegister(ctx context.Context, token, mount string) (*UserKeyInfo, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) UserProvision(ctx context.Context, token, mount, username string) (*UserKeyInfo, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) GetUserPublicKey(ctx context.Context, token, mount, username string) (*UserKeyInfo, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) ListUsers(ctx context.Context, token, mount string) ([]string, error) { return nil, nil } func (m *mockVault) UserEncrypt(ctx context.Context, token, mount, plaintext, metadata string, recipients []string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) UserDecrypt(ctx context.Context, token, mount, envelope string) (*UserDecryptResult, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) UserReEncrypt(ctx context.Context, token, mount, envelope string) (string, error) { return "", fmt.Errorf("not implemented") } func (m *mockVault) UserRotateKey(ctx context.Context, token, mount string) (*UserKeyInfo, error) { return nil, fmt.Errorf("not implemented") } func (m *mockVault) UserDeleteUser(ctx context.Context, token, mount, username string) error { return nil } func (m *mockVault) Close() error { return nil } // ---- handleTGZDownload tests ---- func TestHandleTGZDownload(t *testing.T) { t.Run("valid token serves archive and removes it from cache", func(t *testing.T) { ws := newTestWebServer(t, &mockVault{}) const dlToken = "abc123" ws.tgzCache.Store(dlToken, &tgzEntry{ filename: "test.tgz", data: []byte("fake-tgz-data"), }) r := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken}) r = addAuthCookie(r, &TokenInfo{Username: "testuser"}) w := httptest.NewRecorder() ws.handleTGZDownload(w, r) if w.Code != http.StatusOK { t.Errorf("status = %d, want %d", w.Code, http.StatusOK) } if ct := w.Header().Get("Content-Type"); ct != "application/gzip" { t.Errorf("Content-Type = %q, want application/gzip", ct) } if body := w.Body.String(); body != "fake-tgz-data" { t.Errorf("body = %q, want fake-tgz-data", body) } // Entry must have been removed — a second request should 404. w2 := httptest.NewRecorder() r2 := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken}) r2 = addAuthCookie(r2, &TokenInfo{Username: "testuser"}) ws.handleTGZDownload(w2, r2) if w2.Code != http.StatusNotFound { t.Errorf("second request status = %d, want %d", w2.Code, http.StatusNotFound) } }) t.Run("unknown token returns 404", func(t *testing.T) { ws := newTestWebServer(t, &mockVault{}) r := newChiRequest(http.MethodGet, "/pki/download/nosuchtoken", map[string]string{"token": "nosuchtoken"}) r = addAuthCookie(r, &TokenInfo{Username: "testuser"}) w := httptest.NewRecorder() ws.handleTGZDownload(w, r) if w.Code != http.StatusNotFound { t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) } if !strings.Contains(w.Body.String(), "download not found") { t.Errorf("body %q does not contain 'download not found'", w.Body.String()) } }) } // newTestWebServer builds a WebServer wired to the given mock, suitable for unit tests. func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer { t.Helper() staticFS, err := fs.Sub(webui.FS, "static") if err != nil { t.Fatalf("static FS: %v", err) } return &WebServer{ vault: vault, logger: slog.Default(), staticFS: staticFS, csrf: newCSRFProtect(), } } // newChiRequest builds an *http.Request with chi URL params set. func newChiRequest(method, path string, params map[string]string) *http.Request { r := httptest.NewRequest(method, path, nil) rctx := chi.NewRouteContext() for k, v := range params { rctx.URLParams.Add(k, v) } r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) return r } // addAuthCookie attaches a fake token cookie and injects TokenInfo into the request context. func addAuthCookie(r *http.Request, info *TokenInfo) *http.Request { r.AddCookie(&http.Cookie{Name: "metacrypt_token", Value: "test-token"}) return r.WithContext(withTokenInfo(r.Context(), info)) } // ---- handleCertDetail tests ---- func TestHandleCertDetail(t *testing.T) { sampleCert := &CertDetail{ Serial: "01:02:03", Issuer: "myissuer", CommonName: "example.com", SANs: []string{"example.com", "www.example.com"}, Profile: "server", IssuedBy: "testuser", IssuedAt: "2025-01-01T00:00:00Z", ExpiresAt: "2026-01-01T00:00:00Z", CertPEM: "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n", } tests := []struct { name string serial string listMountsFn func(ctx context.Context, token string) ([]MountInfo, error) getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error) wantStatus int wantBodyContains string }{ { name: "success renders cert detail page", serial: "01:02:03", getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return sampleCert, nil }, wantStatus: http.StatusOK, wantBodyContains: "example.com", }, { name: "cert not found returns 404", serial: "ff:ff:ff", getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return nil, status.Error(codes.NotFound, "cert not found") }, wantStatus: http.StatusNotFound, wantBodyContains: "certificate not found", }, { name: "backend error returns 500", serial: "01:02:03", getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return nil, fmt.Errorf("internal error") }, wantStatus: http.StatusInternalServerError, wantBodyContains: "internal error", }, { name: "no CA mount returns 404", serial: "01:02:03", listMountsFn: func(_ context.Context, _ string) ([]MountInfo, error) { return []MountInfo{}, nil }, getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return sampleCert, nil }, wantStatus: http.StatusNotFound, wantBodyContains: "no CA engine mounted", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mock := &mockVault{ listMountsFn: tc.listMountsFn, getCertFn: tc.getCertFn, } ws := newTestWebServer(t, mock) r := newChiRequest(http.MethodGet, "/pki/cert/"+tc.serial, map[string]string{"serial": tc.serial}) r = addAuthCookie(r, &TokenInfo{Username: "testuser"}) w := httptest.NewRecorder() ws.handleCertDetail(w, r) if w.Code != tc.wantStatus { t.Errorf("status = %d, want %d", w.Code, tc.wantStatus) } if tc.wantBodyContains != "" && !strings.Contains(w.Body.String(), tc.wantBodyContains) { t.Errorf("body %q does not contain %q", w.Body.String(), tc.wantBodyContains) } }) } } // ---- handleCertDownload tests ---- func TestHandleCertDownload(t *testing.T) { samplePEM := "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n" sampleCert := &CertDetail{ Serial: "01:02:03", CertPEM: samplePEM, } tests := []struct { name string serial string listMountsFn func(ctx context.Context, token string) ([]MountInfo, error) getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error) wantStatus int wantBodyContains string wantContentType string wantDisposition string }{ { name: "success streams PEM file", serial: "01:02:03", getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return sampleCert, nil }, wantStatus: http.StatusOK, wantBodyContains: samplePEM, wantContentType: "application/x-pem-file", wantDisposition: `attachment; filename="01:02:03.pem"`, }, { name: "cert not found returns 404", serial: "ff:ff:ff", getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return nil, status.Error(codes.NotFound, "cert not found") }, wantStatus: http.StatusNotFound, wantBodyContains: "certificate not found", }, { name: "backend error returns 500", serial: "01:02:03", getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return nil, fmt.Errorf("storage failure") }, wantStatus: http.StatusInternalServerError, wantBodyContains: "storage failure", }, { name: "no CA mount returns 404", serial: "01:02:03", listMountsFn: func(_ context.Context, _ string) ([]MountInfo, error) { return []MountInfo{}, nil }, getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) { return sampleCert, nil }, wantStatus: http.StatusNotFound, wantBodyContains: "no CA engine mounted", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { mock := &mockVault{ listMountsFn: tc.listMountsFn, getCertFn: tc.getCertFn, } ws := newTestWebServer(t, mock) r := newChiRequest(http.MethodGet, "/pki/cert/"+tc.serial+"/download", map[string]string{"serial": tc.serial}) r = addAuthCookie(r, &TokenInfo{Username: "testuser"}) w := httptest.NewRecorder() ws.handleCertDownload(w, r) if w.Code != tc.wantStatus { t.Errorf("status = %d, want %d", w.Code, tc.wantStatus) } if tc.wantBodyContains != "" && !strings.Contains(w.Body.String(), tc.wantBodyContains) { t.Errorf("body %q does not contain %q", w.Body.String(), tc.wantBodyContains) } if tc.wantContentType != "" { if got := w.Header().Get("Content-Type"); got != tc.wantContentType { t.Errorf("Content-Type = %q, want %q", got, tc.wantContentType) } } if tc.wantDisposition != "" { if got := w.Header().Get("Content-Disposition"); got != tc.wantDisposition { t.Errorf("Content-Disposition = %q, want %q", got, tc.wantDisposition) } } }) } }