From 806f63957bac460eabad27d5fb06b6755938aeb9 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Thu, 26 Mar 2026 14:14:11 -0700 Subject: [PATCH] Migrate CSRF, web templates, session cookies, and snapshot to mcdsl CSRF: Replace local csrfProtect with mcdsl/csrf.Protect. Delete internal/webserver/csrf.go. Web: Replace renderTemplate with web.RenderTemplate + csrf.TemplateFunc. Replace extractCookie with web.GetSessionToken. Replace manual session cookie SetCookie with web.SetSessionCookie. Snapshot: Replace local sqliteBackup with mcdsl/db.Snapshot. Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/metacrypt/snapshot.go | 13 +-- internal/webserver/cert_detail_test.go | 2 +- internal/webserver/csrf.go | 119 ------------------------- internal/webserver/csrf_test.go | 78 +++++----------- internal/webserver/routes.go | 11 +-- internal/webserver/server.go | 46 +++------- 6 files changed, 41 insertions(+), 228 deletions(-) delete mode 100644 internal/webserver/csrf.go diff --git a/cmd/metacrypt/snapshot.go b/cmd/metacrypt/snapshot.go index 99ede15..22fee7d 100644 --- a/cmd/metacrypt/snapshot.go +++ b/cmd/metacrypt/snapshot.go @@ -1,12 +1,11 @@ package main import ( - "database/sql" "fmt" - "os" "github.com/spf13/cobra" + mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" "git.wntrmute.dev/kyle/metacrypt/internal/config" "git.wntrmute.dev/kyle/metacrypt/internal/db" ) @@ -42,17 +41,9 @@ func runSnapshot(cmd *cobra.Command, args []string) error { } defer func() { _ = database.Close() }() - if err := sqliteBackup(database, snapshotOutput); err != nil { + if err := mcdsldb.Snapshot(database, snapshotOutput); err != nil { return err } fmt.Printf("Snapshot saved to %s\n", snapshotOutput) return nil } - -func sqliteBackup(srcDB *sql.DB, dstPath string) error { - _, err := srcDB.Exec("VACUUM INTO ?", dstPath) - if err != nil { - return fmt.Errorf("snapshot: %w", err) - } - return os.Chmod(dstPath, 0600) -} diff --git a/internal/webserver/cert_detail_test.go b/internal/webserver/cert_detail_test.go index a8fd8bd..83f57c2 100644 --- a/internal/webserver/cert_detail_test.go +++ b/internal/webserver/cert_detail_test.go @@ -307,7 +307,7 @@ func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer { vault: vault, logger: slog.Default(), staticFS: staticFS, - csrf: newCSRFProtect(), + csrf: newTestCSRF(t), } } diff --git a/internal/webserver/csrf.go b/internal/webserver/csrf.go deleted file mode 100644 index 21b5944..0000000 --- a/internal/webserver/csrf.go +++ /dev/null @@ -1,119 +0,0 @@ -package webserver - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/base64" - "fmt" - "net/http" - "strings" - "sync" -) - -const ( - csrfCookieName = "metacrypt_csrf" - csrfFieldName = "csrf_token" - csrfTokenLen = 32 -) - -// csrfProtect provides CSRF protection using the signed double-submit cookie -// pattern. A random secret is generated at startup. CSRF tokens are an HMAC of -// a random nonce, sent as both a cookie and a hidden form field. On POST the -// middleware verifies that the form field matches the cookie's HMAC. -type csrfProtect struct { - secret []byte - once sync.Once -} - -func newCSRFProtect() *csrfProtect { - secret := make([]byte, 32) - if _, err := rand.Read(secret); err != nil { - panic("csrf: failed to generate secret: " + err.Error()) - } - return &csrfProtect{secret: secret} -} - -// generateToken creates a new CSRF token: base64(nonce) + "." + base64(hmac(nonce)). -func (c *csrfProtect) generateToken() (string, error) { - nonce := make([]byte, csrfTokenLen) - if _, err := rand.Read(nonce); err != nil { - return "", fmt.Errorf("csrf: generate nonce: %w", err) - } - nonceB64 := base64.RawURLEncoding.EncodeToString(nonce) - mac := hmac.New(sha256.New, c.secret) - mac.Write(nonce) - sigB64 := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) - return nonceB64 + "." + sigB64, nil -} - -// validToken checks that a token has a valid HMAC signature. -func (c *csrfProtect) validToken(token string) bool { - parts := strings.SplitN(token, ".", 2) - if len(parts) != 2 { - return false - } - nonce, err := base64.RawURLEncoding.DecodeString(parts[0]) - if err != nil || len(nonce) != csrfTokenLen { - return false - } - sig, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return false - } - mac := hmac.New(sha256.New, c.secret) - mac.Write(nonce) - return hmac.Equal(mac.Sum(nil), sig) -} - -// setToken generates a new CSRF token, sets it as a cookie, and returns it -// for embedding in a form. -func (c *csrfProtect) setToken(w http.ResponseWriter) string { - token, err := c.generateToken() - if err != nil { - return "" - } - http.SetCookie(w, &http.Cookie{ - Name: csrfCookieName, - Value: token, - Path: "/", - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - }) - return token -} - -// middleware returns an HTTP middleware that enforces CSRF validation on -// mutation requests (POST, PUT, PATCH, DELETE). GET/HEAD/OPTIONS are passed -// through. The HTMX hx-post for /v1/seal is excluded since it hits the API -// server directly and uses token auth, not cookies. -func (c *csrfProtect) middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodGet, http.MethodHead, http.MethodOptions: - next.ServeHTTP(w, r) - return - } - - // Read token from form field (works for both regular forms and - // multipart forms since ParseForm/ParseMultipartForm will have - // been called or the field is available via FormValue). - formToken := r.FormValue(csrfFieldName) - - // Read token from cookie. - cookie, err := r.Cookie(csrfCookieName) - if err != nil || cookie.Value == "" { - http.Error(w, "CSRF validation failed", http.StatusForbidden) - return - } - - // Both tokens must be present, match each other, and be validly signed. - if formToken == "" || formToken != cookie.Value || !c.validToken(formToken) { - http.Error(w, "CSRF validation failed", http.StatusForbidden) - return - } - - next.ServeHTTP(w, r) - }) -} diff --git a/internal/webserver/csrf_test.go b/internal/webserver/csrf_test.go index 5f223f2..60e547e 100644 --- a/internal/webserver/csrf_test.go +++ b/internal/webserver/csrf_test.go @@ -1,44 +1,28 @@ package webserver import ( + "crypto/rand" "net/http" "net/http/httptest" + "net/url" "strings" "testing" + + "git.wntrmute.dev/kyle/mcdsl/csrf" ) -func TestCSRFTokenGenerateAndValidate(t *testing.T) { - c := newCSRFProtect() - token, err := c.generateToken() - if err != nil { - t.Fatalf("generateToken: %v", err) - } - if !c.validToken(token) { - t.Fatal("valid token rejected") - } -} - -func TestCSRFTokenInvalidFormats(t *testing.T) { - c := newCSRFProtect() - for _, bad := range []string{"", "nodot", "a.b.c", "abc.def"} { - if c.validToken(bad) { - t.Errorf("should reject %q", bad) - } - } -} - -func TestCSRFTokenCrossSecret(t *testing.T) { - c1 := newCSRFProtect() - c2 := newCSRFProtect() - token, _ := c1.generateToken() - if c2.validToken(token) { - t.Fatal("token from different secret should be rejected") +func newTestCSRF(t *testing.T) *csrf.Protect { + t.Helper() + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + t.Fatal(err) } + return csrf.New(secret, "metacrypt_csrf", "csrf_token") } func TestCSRFMiddlewareAllowsGET(t *testing.T) { - c := newCSRFProtect() - handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := newTestCSRF(t) + handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -50,8 +34,8 @@ func TestCSRFMiddlewareAllowsGET(t *testing.T) { } func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) { - c := newCSRFProtect() - handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c := newTestCSRF(t) + handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar")) @@ -64,17 +48,20 @@ func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) { } func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) { - c := newCSRFProtect() - token, _ := c.generateToken() + c := newTestCSRF(t) - handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Generate a token via SetToken (writes cookie to response). + rec := httptest.NewRecorder() + token := c.SetToken(rec) + + handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) })) - body := csrfFieldName + "=" + token + body := "csrf_token=" + url.QueryEscape(token) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token}) + req.AddCookie(&http.Cookie{Name: "metacrypt_csrf", Value: token}) w := httptest.NewRecorder() handler.ServeHTTP(w, req) @@ -82,24 +69,3 @@ func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) { t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code) } } - -func TestCSRFMiddlewareRejectsMismatch(t *testing.T) { - c := newCSRFProtect() - token1, _ := c.generateToken() - token2, _ := c.generateToken() - - handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - })) - - body := csrfFieldName + "=" + token1 - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body)) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token2}) - - w := httptest.NewRecorder() - handler.ServeHTTP(w, req) - if w.Code != http.StatusForbidden { - t.Fatalf("POST with mismatched tokens should be forbidden, got %d", w.Code) - } -} diff --git a/internal/webserver/routes.go b/internal/webserver/routes.go index fdbc2bd..3b6b542 100644 --- a/internal/webserver/routes.go +++ b/internal/webserver/routes.go @@ -16,6 +16,8 @@ import ( "github.com/go-chi/chi/v5" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "git.wntrmute.dev/kyle/mcdsl/web" ) // splitLines splits a newline-delimited string into non-empty trimmed lines. @@ -224,14 +226,7 @@ func (ws *WebServer) handleLogin(w http.ResponseWriter, r *http.Request) { ws.renderTemplate(w, "login.html", map[string]interface{}{"Error": "Invalid credentials"}) return } - http.SetCookie(w, &http.Cookie{ - Name: "metacrypt_token", - Value: token, - Path: "/", - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteStrictMode, - }) + web.SetSessionCookie(w, "metacrypt_token", token) http.Redirect(w, r, "/dashboard", http.StatusFound) default: http.Error(w, "method not allowed", http.StatusMethodNotAllowed) diff --git a/internal/webserver/server.go b/internal/webserver/server.go index 5d3bfbb..58ea7e7 100644 --- a/internal/webserver/server.go +++ b/internal/webserver/server.go @@ -4,9 +4,9 @@ package webserver import ( "context" + "crypto/rand" "crypto/tls" "fmt" - "html/template" "io/fs" "log/slog" "net/http" @@ -16,6 +16,8 @@ import ( "github.com/go-chi/chi/v5" mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth" + "git.wntrmute.dev/kyle/mcdsl/csrf" + "git.wntrmute.dev/kyle/mcdsl/web" "git.wntrmute.dev/kyle/metacrypt/internal/config" webui "git.wntrmute.dev/kyle/metacrypt/web" ) @@ -116,7 +118,7 @@ type WebServer struct { logger *slog.Logger httpSrv *http.Server staticFS fs.FS - csrf *csrfProtect + csrf *csrf.Protect tgzCache sync.Map // key: UUID string → *tgzEntry userCache sync.Map // key: UUID string → *cachedUsername } @@ -154,12 +156,17 @@ func New(cfg *config.Config, logger *slog.Logger) (*WebServer, error) { return nil, fmt.Errorf("webserver: static FS: %w", err) } + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + return nil, fmt.Errorf("webserver: generate CSRF secret: %w", err) + } + ws := &WebServer{ cfg: cfg, vault: vault, logger: logger, staticFS: staticFS, - csrf: newCSRFProtect(), + csrf: csrf.New(secret, "metacrypt_csrf", "csrf_token"), } if tok := cfg.MCIAS.ServiceToken; tok != "" { @@ -220,7 +227,7 @@ func (lw *loggingResponseWriter) Unwrap() http.ResponseWriter { func (ws *WebServer) Start() error { r := chi.NewRouter() r.Use(ws.loggingMiddleware) - r.Use(ws.csrf.middleware) + r.Use(ws.csrf.Middleware) ws.registerRoutes(r) ws.httpSrv = &http.Server{ @@ -259,36 +266,9 @@ func (ws *WebServer) Shutdown(ctx context.Context) error { } func (ws *WebServer) renderTemplate(w http.ResponseWriter, name string, data interface{}) { - csrfToken := ws.csrf.setToken(w) - - funcMap := template.FuncMap{ - "csrfField": func() template.HTML { - return template.HTML(fmt.Sprintf( - ``, - csrfFieldName, template.HTMLEscapeString(csrfToken), - )) - }, - } - - tmpl, err := template.New("").Funcs(funcMap).ParseFS(webui.FS, - "templates/layout.html", - "templates/"+name, - ) - if err != nil { - ws.logger.Error("parse template", "name", name, "error", err) - http.Error(w, "internal server error", http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - if err := tmpl.ExecuteTemplate(w, "layout", data); err != nil { - ws.logger.Error("execute template", "name", name, "error", err) - } + web.RenderTemplate(w, webui.FS, name, data, ws.csrf.TemplateFunc(w)) } func extractCookie(r *http.Request) string { - c, err := r.Cookie("metacrypt_token") - if err != nil { - return "" - } - return c.Value + return web.GetSessionToken(r, "metacrypt_token") }