CSRF: Replace local csrfProtect with mcdsl/csrf.Protect. Delete internal/webserver/csrf.go. Web: Replace renderTemplate with web.RenderTemplate + csrf.TemplateFunc. Replace extractCookie with web.GetSessionToken. Replace manual session cookie SetCookie with web.SetSessionCookie. Snapshot: Replace local sqliteBackup with mcdsl/db.Snapshot. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
514 lines
17 KiB
Go
514 lines
17 KiB
Go
package webserver
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"io/fs"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
|
|
webui "git.wntrmute.dev/kyle/metacrypt/web"
|
|
)
|
|
|
|
// mockVault is a minimal vaultBackend implementation for tests.
|
|
// All methods return zero values unless overridden via the function fields.
|
|
type mockVault struct {
|
|
statusFn func(ctx context.Context) (string, error)
|
|
validateTokenFn func(ctx context.Context, token string) (*TokenInfo, error)
|
|
listMountsFn func(ctx context.Context, token string) ([]MountInfo, error)
|
|
getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
|
|
issueCertFn func(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error)
|
|
}
|
|
|
|
func (m *mockVault) Status(ctx context.Context) (string, error) {
|
|
if m.statusFn != nil {
|
|
return m.statusFn(ctx)
|
|
}
|
|
return "unsealed", nil
|
|
}
|
|
|
|
func (m *mockVault) Init(ctx context.Context, password string) error { return nil }
|
|
|
|
func (m *mockVault) Unseal(ctx context.Context, password string) error { return nil }
|
|
|
|
func (m *mockVault) Login(ctx context.Context, username, password, totpCode string) (string, error) {
|
|
return "", nil
|
|
}
|
|
|
|
func (m *mockVault) ValidateToken(ctx context.Context, token string) (*TokenInfo, error) {
|
|
if m.validateTokenFn != nil {
|
|
return m.validateTokenFn(ctx, token)
|
|
}
|
|
return &TokenInfo{Username: "testuser", IsAdmin: false}, nil
|
|
}
|
|
|
|
func (m *mockVault) ListMounts(ctx context.Context, token string) ([]MountInfo, error) {
|
|
if m.listMountsFn != nil {
|
|
return m.listMountsFn(ctx, token)
|
|
}
|
|
return []MountInfo{{Name: "pki", Type: "ca"}}, nil
|
|
}
|
|
|
|
func (m *mockVault) Mount(ctx context.Context, token, name, engineType string, config map[string]interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockVault) GetRootCert(ctx context.Context, mount string) ([]byte, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
|
|
func (m *mockVault) GetIssuerCert(ctx context.Context, mount, issuer string) ([]byte, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
|
|
func (m *mockVault) ImportRoot(ctx context.Context, token, mount, certPEM, keyPEM string) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockVault) CreateIssuer(ctx context.Context, token string, req CreateIssuerRequest) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockVault) ListIssuers(ctx context.Context, token, mount string) ([]string, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockVault) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) {
|
|
if m.issueCertFn != nil {
|
|
return m.issueCertFn(ctx, token, req)
|
|
}
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
|
|
func (m *mockVault) SignCSR(ctx context.Context, token string, req SignCSRRequest) (*SignedCert, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
|
|
func (m *mockVault) GetCert(ctx context.Context, token, mount, serial string) (*CertDetail, error) {
|
|
if m.getCertFn != nil {
|
|
return m.getCertFn(ctx, token, mount, serial)
|
|
}
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
|
|
func (m *mockVault) ListCerts(ctx context.Context, token, mount string) ([]CertSummary, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockVault) RevokeCert(ctx context.Context, token, mount, serial string) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockVault) DeleteCert(ctx context.Context, token, mount, serial string) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockVault) ListPolicies(ctx context.Context, token string) ([]PolicyRule, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockVault) GetPolicy(ctx context.Context, token, id string) (*PolicyRule, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockVault) CreatePolicy(ctx context.Context, token string, rule PolicyRule) (*PolicyRule, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (m *mockVault) DeletePolicy(ctx context.Context, token, id string) error { return nil }
|
|
|
|
// SSH CA stubs
|
|
func (m *mockVault) GetSSHCAPublicKey(ctx context.Context, mount string) (*SSHCAPublicKey, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) SSHCASignHost(ctx context.Context, token, mount string, req SSHCASignRequest) (*SSHCASignResult, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) SSHCASignUser(ctx context.Context, token, mount string, req SSHCASignRequest) (*SSHCASignResult, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) ListSSHCAProfiles(ctx context.Context, token, mount string) ([]SSHCAProfileSummary, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *mockVault) GetSSHCAProfile(ctx context.Context, token, mount, name string) (*SSHCAProfile, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) CreateSSHCAProfile(ctx context.Context, token, mount string, req SSHCAProfileRequest) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) UpdateSSHCAProfile(ctx context.Context, token, mount, name string, req SSHCAProfileRequest) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) DeleteSSHCAProfile(ctx context.Context, token, mount, name string) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) ListSSHCACerts(ctx context.Context, token, mount string) ([]SSHCACertSummary, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *mockVault) GetSSHCACert(ctx context.Context, token, mount, serial string) (*SSHCACertDetail, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) RevokeSSHCACert(ctx context.Context, token, mount, serial string) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) DeleteSSHCACert(ctx context.Context, token, mount, serial string) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) GetSSHCAKRL(ctx context.Context, mount string) ([]byte, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
|
|
// Transit stubs
|
|
func (m *mockVault) ListTransitKeys(ctx context.Context, token, mount string) ([]TransitKeySummary, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *mockVault) GetTransitKey(ctx context.Context, token, mount, name string) (*TransitKeyDetail, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) CreateTransitKey(ctx context.Context, token, mount, name, keyType string) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) DeleteTransitKey(ctx context.Context, token, mount, name string) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) RotateTransitKey(ctx context.Context, token, mount, name string) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) UpdateTransitKeyConfig(ctx context.Context, token, mount, name string, minDecryptVersion int, allowDeletion bool) error {
|
|
return nil
|
|
}
|
|
func (m *mockVault) TrimTransitKey(ctx context.Context, token, mount, name string) (int, error) {
|
|
return 0, nil
|
|
}
|
|
func (m *mockVault) TransitEncrypt(ctx context.Context, token, mount, key, plaintext, transitCtx string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) TransitDecrypt(ctx context.Context, token, mount, key, ciphertext, transitCtx string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) TransitRewrap(ctx context.Context, token, mount, key, ciphertext, transitCtx string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) TransitSign(ctx context.Context, token, mount, key, input string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) TransitVerify(ctx context.Context, token, mount, key, input, signature string) (bool, error) {
|
|
return false, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) TransitHMAC(ctx context.Context, token, mount, key, input string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) GetTransitPublicKey(ctx context.Context, token, mount, name string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
|
|
// User stubs
|
|
func (m *mockVault) UserRegister(ctx context.Context, token, mount string) (*UserKeyInfo, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) UserProvision(ctx context.Context, token, mount, username string) (*UserKeyInfo, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) GetUserPublicKey(ctx context.Context, token, mount, username string) (*UserKeyInfo, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) ListUsers(ctx context.Context, token, mount string) ([]string, error) {
|
|
return nil, nil
|
|
}
|
|
func (m *mockVault) UserEncrypt(ctx context.Context, token, mount, plaintext, metadata string, recipients []string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) UserDecrypt(ctx context.Context, token, mount, envelope string) (*UserDecryptResult, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) UserReEncrypt(ctx context.Context, token, mount, envelope string) (string, error) {
|
|
return "", fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) UserRotateKey(ctx context.Context, token, mount string) (*UserKeyInfo, error) {
|
|
return nil, fmt.Errorf("not implemented")
|
|
}
|
|
func (m *mockVault) UserDeleteUser(ctx context.Context, token, mount, username string) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockVault) Close() error { return nil }
|
|
|
|
// ---- handleTGZDownload tests ----
|
|
|
|
func TestHandleTGZDownload(t *testing.T) {
|
|
t.Run("valid token serves archive and removes it from cache", func(t *testing.T) {
|
|
ws := newTestWebServer(t, &mockVault{})
|
|
|
|
const dlToken = "abc123"
|
|
ws.tgzCache.Store(dlToken, &tgzEntry{
|
|
filename: "test.tgz",
|
|
data: []byte("fake-tgz-data"),
|
|
})
|
|
|
|
r := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken})
|
|
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
|
|
|
|
w := httptest.NewRecorder()
|
|
ws.handleTGZDownload(w, r)
|
|
|
|
if w.Code != http.StatusOK {
|
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
|
}
|
|
if ct := w.Header().Get("Content-Type"); ct != "application/gzip" {
|
|
t.Errorf("Content-Type = %q, want application/gzip", ct)
|
|
}
|
|
if body := w.Body.String(); body != "fake-tgz-data" {
|
|
t.Errorf("body = %q, want fake-tgz-data", body)
|
|
}
|
|
|
|
// Entry must have been removed — a second request should 404.
|
|
w2 := httptest.NewRecorder()
|
|
r2 := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken})
|
|
r2 = addAuthCookie(r2, &TokenInfo{Username: "testuser"})
|
|
ws.handleTGZDownload(w2, r2)
|
|
if w2.Code != http.StatusNotFound {
|
|
t.Errorf("second request status = %d, want %d", w2.Code, http.StatusNotFound)
|
|
}
|
|
})
|
|
|
|
t.Run("unknown token returns 404", func(t *testing.T) {
|
|
ws := newTestWebServer(t, &mockVault{})
|
|
|
|
r := newChiRequest(http.MethodGet, "/pki/download/nosuchtoken", map[string]string{"token": "nosuchtoken"})
|
|
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
|
|
|
|
w := httptest.NewRecorder()
|
|
ws.handleTGZDownload(w, r)
|
|
|
|
if w.Code != http.StatusNotFound {
|
|
t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound)
|
|
}
|
|
if !strings.Contains(w.Body.String(), "download not found") {
|
|
t.Errorf("body %q does not contain 'download not found'", w.Body.String())
|
|
}
|
|
})
|
|
}
|
|
|
|
// newTestWebServer builds a WebServer wired to the given mock, suitable for unit tests.
|
|
func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
|
|
t.Helper()
|
|
staticFS, err := fs.Sub(webui.FS, "static")
|
|
if err != nil {
|
|
t.Fatalf("static FS: %v", err)
|
|
}
|
|
return &WebServer{
|
|
vault: vault,
|
|
logger: slog.Default(),
|
|
staticFS: staticFS,
|
|
csrf: newTestCSRF(t),
|
|
}
|
|
}
|
|
|
|
// newChiRequest builds an *http.Request with chi URL params set.
|
|
func newChiRequest(method, path string, params map[string]string) *http.Request {
|
|
r := httptest.NewRequest(method, path, nil)
|
|
rctx := chi.NewRouteContext()
|
|
for k, v := range params {
|
|
rctx.URLParams.Add(k, v)
|
|
}
|
|
r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx))
|
|
return r
|
|
}
|
|
|
|
// addAuthCookie attaches a fake token cookie and injects TokenInfo into the request context.
|
|
func addAuthCookie(r *http.Request, info *TokenInfo) *http.Request {
|
|
r.AddCookie(&http.Cookie{Name: "metacrypt_token", Value: "test-token"})
|
|
return r.WithContext(withTokenInfo(r.Context(), info))
|
|
}
|
|
|
|
// ---- handleCertDetail tests ----
|
|
|
|
func TestHandleCertDetail(t *testing.T) {
|
|
sampleCert := &CertDetail{
|
|
Serial: "01:02:03",
|
|
Issuer: "myissuer",
|
|
CommonName: "example.com",
|
|
SANs: []string{"example.com", "www.example.com"},
|
|
Profile: "server",
|
|
IssuedBy: "testuser",
|
|
IssuedAt: "2025-01-01T00:00:00Z",
|
|
ExpiresAt: "2026-01-01T00:00:00Z",
|
|
CertPEM: "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n",
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
serial string
|
|
listMountsFn func(ctx context.Context, token string) ([]MountInfo, error)
|
|
getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
|
|
wantStatus int
|
|
wantBodyContains string
|
|
}{
|
|
{
|
|
name: "success renders cert detail page",
|
|
serial: "01:02:03",
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return sampleCert, nil
|
|
},
|
|
wantStatus: http.StatusOK,
|
|
wantBodyContains: "example.com",
|
|
},
|
|
{
|
|
name: "cert not found returns 404",
|
|
serial: "ff:ff:ff",
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return nil, status.Error(codes.NotFound, "cert not found")
|
|
},
|
|
wantStatus: http.StatusNotFound,
|
|
wantBodyContains: "certificate not found",
|
|
},
|
|
{
|
|
name: "backend error returns 500",
|
|
serial: "01:02:03",
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return nil, fmt.Errorf("internal error")
|
|
},
|
|
wantStatus: http.StatusInternalServerError,
|
|
wantBodyContains: "internal error",
|
|
},
|
|
{
|
|
name: "no CA mount returns 404",
|
|
serial: "01:02:03",
|
|
listMountsFn: func(_ context.Context, _ string) ([]MountInfo, error) {
|
|
return []MountInfo{}, nil
|
|
},
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return sampleCert, nil
|
|
},
|
|
wantStatus: http.StatusNotFound,
|
|
wantBodyContains: "no CA engine mounted",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mock := &mockVault{
|
|
listMountsFn: tc.listMountsFn,
|
|
getCertFn: tc.getCertFn,
|
|
}
|
|
ws := newTestWebServer(t, mock)
|
|
|
|
r := newChiRequest(http.MethodGet, "/pki/cert/"+tc.serial, map[string]string{"serial": tc.serial})
|
|
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
|
|
|
|
w := httptest.NewRecorder()
|
|
ws.handleCertDetail(w, r)
|
|
|
|
if w.Code != tc.wantStatus {
|
|
t.Errorf("status = %d, want %d", w.Code, tc.wantStatus)
|
|
}
|
|
if tc.wantBodyContains != "" && !strings.Contains(w.Body.String(), tc.wantBodyContains) {
|
|
t.Errorf("body %q does not contain %q", w.Body.String(), tc.wantBodyContains)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// ---- handleCertDownload tests ----
|
|
|
|
func TestHandleCertDownload(t *testing.T) {
|
|
samplePEM := "-----BEGIN CERTIFICATE-----\nfake\n-----END CERTIFICATE-----\n"
|
|
sampleCert := &CertDetail{
|
|
Serial: "01:02:03",
|
|
CertPEM: samplePEM,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
serial string
|
|
listMountsFn func(ctx context.Context, token string) ([]MountInfo, error)
|
|
getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
|
|
wantStatus int
|
|
wantBodyContains string
|
|
wantContentType string
|
|
wantDisposition string
|
|
}{
|
|
{
|
|
name: "success streams PEM file",
|
|
serial: "01:02:03",
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return sampleCert, nil
|
|
},
|
|
wantStatus: http.StatusOK,
|
|
wantBodyContains: samplePEM,
|
|
wantContentType: "application/x-pem-file",
|
|
wantDisposition: `attachment; filename="01:02:03.pem"`,
|
|
},
|
|
{
|
|
name: "cert not found returns 404",
|
|
serial: "ff:ff:ff",
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return nil, status.Error(codes.NotFound, "cert not found")
|
|
},
|
|
wantStatus: http.StatusNotFound,
|
|
wantBodyContains: "certificate not found",
|
|
},
|
|
{
|
|
name: "backend error returns 500",
|
|
serial: "01:02:03",
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return nil, fmt.Errorf("storage failure")
|
|
},
|
|
wantStatus: http.StatusInternalServerError,
|
|
wantBodyContains: "storage failure",
|
|
},
|
|
{
|
|
name: "no CA mount returns 404",
|
|
serial: "01:02:03",
|
|
listMountsFn: func(_ context.Context, _ string) ([]MountInfo, error) {
|
|
return []MountInfo{}, nil
|
|
},
|
|
getCertFn: func(_ context.Context, _, _, _ string) (*CertDetail, error) {
|
|
return sampleCert, nil
|
|
},
|
|
wantStatus: http.StatusNotFound,
|
|
wantBodyContains: "no CA engine mounted",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
mock := &mockVault{
|
|
listMountsFn: tc.listMountsFn,
|
|
getCertFn: tc.getCertFn,
|
|
}
|
|
ws := newTestWebServer(t, mock)
|
|
|
|
r := newChiRequest(http.MethodGet, "/pki/cert/"+tc.serial+"/download", map[string]string{"serial": tc.serial})
|
|
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
|
|
|
|
w := httptest.NewRecorder()
|
|
ws.handleCertDownload(w, r)
|
|
|
|
if w.Code != tc.wantStatus {
|
|
t.Errorf("status = %d, want %d", w.Code, tc.wantStatus)
|
|
}
|
|
if tc.wantBodyContains != "" && !strings.Contains(w.Body.String(), tc.wantBodyContains) {
|
|
t.Errorf("body %q does not contain %q", w.Body.String(), tc.wantBodyContains)
|
|
}
|
|
if tc.wantContentType != "" {
|
|
if got := w.Header().Get("Content-Type"); got != tc.wantContentType {
|
|
t.Errorf("Content-Type = %q, want %q", got, tc.wantContentType)
|
|
}
|
|
}
|
|
if tc.wantDisposition != "" {
|
|
if got := w.Header().Get("Content-Disposition"); got != tc.wantDisposition {
|
|
t.Errorf("Content-Disposition = %q, want %q", got, tc.wantDisposition)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|