Files
metacrypt/internal/webserver/cert_detail_test.go
Kyle Isom bbe382dc10 Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:05:59 -07:00

514 lines
17 KiB
Go

package webserver
import (
"context"
"fmt"
"io/fs"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
webui "git.wntrmute.dev/mc/metacrypt/web"
)
// mockVault is a minimal vaultBackend implementation for tests.
//
// Each xxxFn field, when non-nil, replaces the canned behavior of the method
// with the matching name. A nil hook falls back to that method's fixed stub
// result, which is not always a zero value. Status, for example, reports
// "unsealed", and several stubs return a "not implemented" error.
type mockVault struct {
	statusFn        func(ctx context.Context) (string, error)
	validateTokenFn func(ctx context.Context, token string) (*TokenInfo, error)
	listMountsFn    func(ctx context.Context, token string) ([]MountInfo, error)
	getCertFn       func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
	issueCertFn     func(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error)
}

// Status delegates to statusFn when set; otherwise the vault reports "unsealed".
func (m *mockVault) Status(ctx context.Context) (string, error) {
	if fn := m.statusFn; fn != nil {
		return fn(ctx)
	}
	return "unsealed", nil
}

// Init is a no-op that always succeeds.
func (m *mockVault) Init(ctx context.Context, password string) error {
	return nil
}

// Unseal is a no-op that always succeeds.
func (m *mockVault) Unseal(ctx context.Context, password string) error {
	return nil
}

// Login always succeeds and hands back an empty token.
func (m *mockVault) Login(ctx context.Context, username, password, totpCode string) (string, error) {
	return "", nil
}

// ValidateToken delegates to validateTokenFn when set; otherwise every token
// resolves to the non-admin user "testuser".
func (m *mockVault) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
	if fn := m.validateTokenFn; fn != nil {
		return fn(ctx, token)
	}
	info := &TokenInfo{Username: "testuser", IsAdmin: false}
	return info, nil
}

// ListMounts delegates to listMountsFn when set; otherwise it reports a single
// engine named "pki" of type "ca".
func (m *mockVault) ListMounts(ctx context.Context, token string) ([]MountInfo, error) {
	if fn := m.listMountsFn; fn != nil {
		return fn(ctx, token)
	}
	mounts := []MountInfo{{Name: "pki", Type: "ca"}}
	return mounts, nil
}
// Mount is a no-op that always succeeds. The config parameter is spelled
// map[string]any. any is a predeclared alias for interface{}, so this is the
// same type as before and the method still satisfies vaultBackend.
func (m *mockVault) Mount(ctx context.Context, token, name, engineType string, config map[string]any) error {
	return nil
}
// PKI and policy stubs. The following behave as described:
//   - GetRootCert, GetIssuerCert and SignCSR always fail with "not implemented".
//   - IssueCert and GetCert fail the same way unless their hook field is set.
//   - ListIssuers, ListCerts, ListPolicies, GetPolicy and CreatePolicy return
//     nil results with no error.
//   - The remaining mutations are successful no-ops.

func (m *mockVault) GetRootCert(ctx context.Context, mount string) ([]byte, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) GetIssuerCert(ctx context.Context, mount, issuer string) ([]byte, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) ImportRoot(ctx context.Context, token, mount, certPEM, keyPEM string) error { return nil }

func (m *mockVault) CreateIssuer(ctx context.Context, token string, req CreateIssuerRequest) error { return nil }

func (m *mockVault) ListIssuers(ctx context.Context, token, mount string) ([]string, error) { return nil, nil }

// IssueCert delegates to issueCertFn; without a hook it always fails.
func (m *mockVault) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) {
	if m.issueCertFn == nil {
		return nil, fmt.Errorf("not implemented")
	}
	return m.issueCertFn(ctx, token, req)
}

func (m *mockVault) SignCSR(ctx context.Context, token string, req SignCSRRequest) (*SignedCert, error) {
	return nil, fmt.Errorf("not implemented")
}

// GetCert delegates to getCertFn; without a hook it always fails.
func (m *mockVault) GetCert(ctx context.Context, token, mount, serial string) (*CertDetail, error) {
	if m.getCertFn == nil {
		return nil, fmt.Errorf("not implemented")
	}
	return m.getCertFn(ctx, token, mount, serial)
}

func (m *mockVault) ListCerts(ctx context.Context, token, mount string) ([]CertSummary, error) { return nil, nil }

func (m *mockVault) RevokeCert(ctx context.Context, token, mount, serial string) error { return nil }

func (m *mockVault) DeleteCert(ctx context.Context, token, mount, serial string) error { return nil }

func (m *mockVault) ListPolicies(ctx context.Context, token string) ([]PolicyRule, error) { return nil, nil }

func (m *mockVault) GetPolicy(ctx context.Context, token, id string) (*PolicyRule, error) { return nil, nil }

func (m *mockVault) CreatePolicy(ctx context.Context, token string, rule PolicyRule) (*PolicyRule, error) {
	return nil, nil
}

func (m *mockVault) DeletePolicy(ctx context.Context, token, id string) error { return nil }
// SSH CA stubs. Public-key, KRL, profile and certificate lookups fail with
// "not implemented", as do host and user signing. The two listings return
// nil with no error, and create/update/revoke/delete are successful no-ops.

func (m *mockVault) GetSSHCAPublicKey(ctx context.Context, mount string) (*SSHCAPublicKey, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) SSHCASignHost(ctx context.Context, token, mount string, req SSHCASignRequest) (*SSHCASignResult, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) SSHCASignUser(ctx context.Context, token, mount string, req SSHCASignRequest) (*SSHCASignResult, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) ListSSHCAProfiles(ctx context.Context, token, mount string) ([]SSHCAProfileSummary, error) { return nil, nil }

func (m *mockVault) GetSSHCAProfile(ctx context.Context, token, mount, name string) (*SSHCAProfile, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) CreateSSHCAProfile(ctx context.Context, token, mount string, req SSHCAProfileRequest) error { return nil }

func (m *mockVault) UpdateSSHCAProfile(ctx context.Context, token, mount, name string, req SSHCAProfileRequest) error { return nil }

func (m *mockVault) DeleteSSHCAProfile(ctx context.Context, token, mount, name string) error { return nil }

func (m *mockVault) ListSSHCACerts(ctx context.Context, token, mount string) ([]SSHCACertSummary, error) { return nil, nil }

func (m *mockVault) GetSSHCACert(ctx context.Context, token, mount, serial string) (*SSHCACertDetail, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) RevokeSSHCACert(ctx context.Context, token, mount, serial string) error { return nil }

func (m *mockVault) DeleteSSHCACert(ctx context.Context, token, mount, serial string) error { return nil }

func (m *mockVault) GetSSHCAKRL(ctx context.Context, mount string) ([]byte, error) {
	return nil, fmt.Errorf("not implemented")
}
// Transit stubs. The following behave as described:
//   - Key lookup, encrypt/decrypt/rewrap, sign/verify, HMAC and public-key
//     export fail with "not implemented". TransitVerify also reports false.
//   - ListTransitKeys returns nil and TrimTransitKey returns 0, both with no
//     error.
//   - Key create/delete/rotate/config updates are successful no-ops.

func (m *mockVault) ListTransitKeys(ctx context.Context, token, mount string) ([]TransitKeySummary, error) { return nil, nil }

func (m *mockVault) GetTransitKey(ctx context.Context, token, mount, name string) (*TransitKeyDetail, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) CreateTransitKey(ctx context.Context, token, mount, name, keyType string) error { return nil }

func (m *mockVault) DeleteTransitKey(ctx context.Context, token, mount, name string) error { return nil }

func (m *mockVault) RotateTransitKey(ctx context.Context, token, mount, name string) error { return nil }

func (m *mockVault) UpdateTransitKeyConfig(ctx context.Context, token, mount, name string, minDecryptVersion int, allowDeletion bool) error {
	return nil
}

func (m *mockVault) TrimTransitKey(ctx context.Context, token, mount, name string) (int, error) { return 0, nil }

func (m *mockVault) TransitEncrypt(ctx context.Context, token, mount, key, plaintext, transitCtx string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) TransitDecrypt(ctx context.Context, token, mount, key, ciphertext, transitCtx string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) TransitRewrap(ctx context.Context, token, mount, key, ciphertext, transitCtx string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) TransitSign(ctx context.Context, token, mount, key, input string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) TransitVerify(ctx context.Context, token, mount, key, input, signature string) (bool, error) {
	return false, fmt.Errorf("not implemented")
}

func (m *mockVault) TransitHMAC(ctx context.Context, token, mount, key, input string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) GetTransitPublicKey(ctx context.Context, token, mount, name string) (string, error) {
	return "", fmt.Errorf("not implemented")
}
// User-engine stubs. Registration, provisioning, key lookup/rotation and the
// encrypt/decrypt/re-encrypt calls fail with "not implemented". ListUsers
// returns nil, while UserDeleteUser and Close succeed without doing anything.

func (m *mockVault) UserRegister(ctx context.Context, token, mount string) (*UserKeyInfo, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) UserProvision(ctx context.Context, token, mount, username string) (*UserKeyInfo, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) GetUserPublicKey(ctx context.Context, token, mount, username string) (*UserKeyInfo, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) ListUsers(ctx context.Context, token, mount string) ([]string, error) { return nil, nil }

func (m *mockVault) UserEncrypt(ctx context.Context, token, mount, plaintext, metadata string, recipients []string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) UserDecrypt(ctx context.Context, token, mount, envelope string) (*UserDecryptResult, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) UserReEncrypt(ctx context.Context, token, mount, envelope string) (string, error) {
	return "", fmt.Errorf("not implemented")
}

func (m *mockVault) UserRotateKey(ctx context.Context, token, mount string) (*UserKeyInfo, error) {
	return nil, fmt.Errorf("not implemented")
}

func (m *mockVault) UserDeleteUser(ctx context.Context, token, mount, username string) error { return nil }

// Close has nothing to release.
func (m *mockVault) Close() error { return nil }
// ---- handleTGZDownload tests ----
func TestHandleTGZDownload(t *testing.T) {
t.Run("valid token serves archive and removes it from cache", func(t *testing.T) {
ws := newTestWebServer(t, &mockVault{})
const dlToken = "abc123"
ws.tgzCache.Store(dlToken, &tgzEntry{
filename: "test.tgz",
data: []byte("fake-tgz-data"),
})
r := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken})
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
w := httptest.NewRecorder()
ws.handleTGZDownload(w, r)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
if ct := w.Header().Get("Content-Type"); ct != "application/gzip" {
t.Errorf("Content-Type = %q, want application/gzip", ct)
}
if body := w.Body.String(); body != "fake-tgz-data" {
t.Errorf("body = %q, want fake-tgz-data", body)
}
// Entry must have been removed — a second request should 404.
w2 := httptest.NewRecorder()
r2 := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken})
r2 = addAuthCookie(r2, &TokenInfo{Username: "testuser"})
ws.handleTGZDownload(w2, r2)
if w2.Code != http.StatusNotFound {
t.Errorf("second request status = %d, want %d", w2.Code, http.StatusNotFound)
}
})
t.Run("unknown token returns 404", func(t *testing.T) {
ws := newTestWebServer(t, &mockVault{})
r := newChiRequest(http.MethodGet, "/pki/download/nosuchtoken", map[string]string{"token": "nosuchtoken"})
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
w := httptest.NewRecorder()
ws.handleTGZDownload(w, r)
if w.Code != http.StatusNotFound {
t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound)
}
if !strings.Contains(w.Body.String(), "download not found") {
t.Errorf("body %q does not contain 'download not found'", w.Body.String())
}
})
}
// newTestWebServer returns a *WebServer suitable for handler unit tests.
//
// The server has the following wiring:
//   - vault is the given backend.
//   - logger is the default slog logger.
//   - staticFS is the "static" subtree of the embedded web assets.
//   - csrf comes from newTestCSRF.
//
// Every other field keeps its zero value. The test is aborted if the static
// sub-filesystem cannot be derived.
func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
	t.Helper()

	assets, err := fs.Sub(webui.FS, "static")
	if err != nil {
		t.Fatalf("static FS: %v", err)
	}

	ws := &WebServer{
		vault:    vault,
		logger:   slog.Default(),
		staticFS: assets,
		csrf:     newTestCSRF(t),
	}
	return ws
}
// newChiRequest returns an httptest request for method and path. Its context
// carries a chi route context holding params, so handlers that read chi URL
// params see them without going through a router.
func newChiRequest(method, path string, params map[string]string) *http.Request {
	routeCtx := chi.NewRouteContext()
	for key, val := range params {
		routeCtx.URLParams.Add(key, val)
	}

	req := httptest.NewRequest(method, path, nil)
	ctx := context.WithValue(req.Context(), chi.RouteCtxKey, routeCtx)
	return req.WithContext(ctx)
}
// addAuthCookie makes r look authenticated. It adds a "metacrypt_token" cookie
// with a placeholder value to r itself. It then returns a shallow copy of r
// whose context carries info, stored via withTokenInfo.
func addAuthCookie(r *http.Request, info *TokenInfo) *http.Request {
	cookie := &http.Cookie{Name: "metacrypt_token", Value: "test-token"}
	r.AddCookie(cookie)

	ctx := withTokenInfo(r.Context(), info)
	return r.WithContext(ctx)
}
// ---- handleCertDetail tests ----

// TestHandleCertDetail exercises handleCertDetail in four cases: success,
// certificate not found, a generic backend failure, and no CA mount. Each
// case checks the response status and an expected substring of the body.
func TestHandleCertDetail(t *testing.T) {
	sampleCert := &CertDetail{
		Serial:     "01:02:03",
		Issuer:     "myissuer",
		CommonName: "example.com",
		SANs:       []string{"example.com", "www.example.com"},
		Profile:    "server",
		IssuedBy:   "testuser",
		IssuedAt:   "2025-01-01T00:00:00Z",
		ExpiresAt:  "2026-01-01T00:00:00Z",
		CertPEM:    "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n",
	}

	// Hook builders shared by the table below.
	returnSample := func(context.Context, string, string, string) (*CertDetail, error) {
		return sampleCert, nil
	}
	failWith := func(err error) func(context.Context, string, string, string) (*CertDetail, error) {
		return func(context.Context, string, string, string) (*CertDetail, error) {
			return nil, err
		}
	}
	noMounts := func(context.Context, string) ([]MountInfo, error) {
		return []MountInfo{}, nil
	}

	tests := []struct {
		name             string
		serial           string
		listMountsFn     func(ctx context.Context, token string) ([]MountInfo, error)
		getCertFn        func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
		wantStatus       int
		wantBodyContains string
	}{
		{
			name:             "success renders cert detail page",
			serial:           "01:02:03",
			getCertFn:        returnSample,
			wantStatus:       http.StatusOK,
			wantBodyContains: "example.com",
		},
		{
			name:             "cert not found returns 404",
			serial:           "ff:ff:ff",
			getCertFn:        failWith(status.Error(codes.NotFound, "cert not found")),
			wantStatus:       http.StatusNotFound,
			wantBodyContains: "certificate not found",
		},
		{
			name:             "backend error returns 500",
			serial:           "01:02:03",
			getCertFn:        failWith(fmt.Errorf("internal error")),
			wantStatus:       http.StatusInternalServerError,
			wantBodyContains: "internal error",
		},
		{
			name:             "no CA mount returns 404",
			serial:           "01:02:03",
			listMountsFn:     noMounts,
			getCertFn:        returnSample,
			wantStatus:       http.StatusNotFound,
			wantBodyContains: "no CA engine mounted",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ws := newTestWebServer(t, &mockVault{
				listMountsFn: tc.listMountsFn,
				getCertFn:    tc.getCertFn,
			})

			req := newChiRequest(http.MethodGet, "/pki/cert/"+tc.serial, map[string]string{"serial": tc.serial})
			req = addAuthCookie(req, &TokenInfo{Username: "testuser"})
			rec := httptest.NewRecorder()
			ws.handleCertDetail(rec, req)

			if rec.Code != tc.wantStatus {
				t.Errorf("status = %d, want %d", rec.Code, tc.wantStatus)
			}
			if body := rec.Body.String(); tc.wantBodyContains != "" && !strings.Contains(body, tc.wantBodyContains) {
				t.Errorf("body %q does not contain %q", body, tc.wantBodyContains)
			}
		})
	}
}
// ---- handleCertDownload tests ----

// TestHandleCertDownload exercises handleCertDownload in four cases: success,
// certificate not found, a generic backend failure, and no CA mount. On
// success it also checks the PEM body and the Content-Type and
// Content-Disposition headers.
func TestHandleCertDownload(t *testing.T) {
	const samplePEM = "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n"
	sampleCert := &CertDetail{Serial: "01:02:03", CertPEM: samplePEM}

	// Hook builders shared by the table below.
	returnSample := func(context.Context, string, string, string) (*CertDetail, error) {
		return sampleCert, nil
	}
	failWith := func(err error) func(context.Context, string, string, string) (*CertDetail, error) {
		return func(context.Context, string, string, string) (*CertDetail, error) {
			return nil, err
		}
	}
	noMounts := func(context.Context, string) ([]MountInfo, error) {
		return []MountInfo{}, nil
	}

	tests := []struct {
		name             string
		serial           string
		listMountsFn     func(ctx context.Context, token string) ([]MountInfo, error)
		getCertFn        func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
		wantStatus       int
		wantBodyContains string
		wantContentType  string
		wantDisposition  string
	}{
		{
			name:             "success streams PEM file",
			serial:           "01:02:03",
			getCertFn:        returnSample,
			wantStatus:       http.StatusOK,
			wantBodyContains: samplePEM,
			wantContentType:  "application/x-pem-file",
			wantDisposition:  `attachment; filename="01:02:03.pem"`,
		},
		{
			name:             "cert not found returns 404",
			serial:           "ff:ff:ff",
			getCertFn:        failWith(status.Error(codes.NotFound, "cert not found")),
			wantStatus:       http.StatusNotFound,
			wantBodyContains: "certificate not found",
		},
		{
			name:             "backend error returns 500",
			serial:           "01:02:03",
			getCertFn:        failWith(fmt.Errorf("storage failure")),
			wantStatus:       http.StatusInternalServerError,
			wantBodyContains: "storage failure",
		},
		{
			name:             "no CA mount returns 404",
			serial:           "01:02:03",
			listMountsFn:     noMounts,
			getCertFn:        returnSample,
			wantStatus:       http.StatusNotFound,
			wantBodyContains: "no CA engine mounted",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			ws := newTestWebServer(t, &mockVault{
				listMountsFn: tc.listMountsFn,
				getCertFn:    tc.getCertFn,
			})

			target := "/pki/cert/" + tc.serial + "/download"
			req := newChiRequest(http.MethodGet, target, map[string]string{"serial": tc.serial})
			req = addAuthCookie(req, &TokenInfo{Username: "testuser"})
			rec := httptest.NewRecorder()
			ws.handleCertDownload(rec, req)

			if rec.Code != tc.wantStatus {
				t.Errorf("status = %d, want %d", rec.Code, tc.wantStatus)
			}
			if body := rec.Body.String(); tc.wantBodyContains != "" && !strings.Contains(body, tc.wantBodyContains) {
				t.Errorf("body %q does not contain %q", body, tc.wantBodyContains)
			}

			// Header expectations are optional; an empty want skips the check.
			checkHeader := func(key, want string) {
				t.Helper()
				if want == "" {
					return
				}
				if got := rec.Header().Get(key); got != want {
					t.Errorf("%s = %q, want %q", key, got, want)
				}
			}
			checkHeader("Content-Type", tc.wantContentType)
			checkHeader("Content-Disposition", tc.wantDisposition)
		})
	}
}