Migrate CSRF, web templates, session cookies, and snapshot to mcdsl

CSRF: Replace local csrfProtect with mcdsl/csrf.Protect. Delete
internal/webserver/csrf.go.

Web: Replace renderTemplate with web.RenderTemplate + csrf.TemplateFunc.
Replace extractCookie with web.GetSessionToken. Replace manual session
cookie SetCookie with web.SetSessionCookie.

Snapshot: Replace local sqliteBackup with mcdsl/db.Snapshot.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 14:14:11 -07:00
parent 2a927e5359
commit 806f63957b
6 changed files with 41 additions and 228 deletions

View File

@@ -1,44 +1,28 @@
package webserver
import (
"crypto/rand"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"git.wntrmute.dev/kyle/mcdsl/csrf"
)
func TestCSRFTokenGenerateAndValidate(t *testing.T) {
c := newCSRFProtect()
token, err := c.generateToken()
if err != nil {
t.Fatalf("generateToken: %v", err)
}
if !c.validToken(token) {
t.Fatal("valid token rejected")
}
}
func TestCSRFTokenInvalidFormats(t *testing.T) {
c := newCSRFProtect()
for _, bad := range []string{"", "nodot", "a.b.c", "abc.def"} {
if c.validToken(bad) {
t.Errorf("should reject %q", bad)
}
}
}
func TestCSRFTokenCrossSecret(t *testing.T) {
c1 := newCSRFProtect()
c2 := newCSRFProtect()
token, _ := c1.generateToken()
if c2.validToken(token) {
t.Fatal("token from different secret should be rejected")
func newTestCSRF(t *testing.T) *csrf.Protect {
t.Helper()
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
t.Fatal(err)
}
return csrf.New(secret, "metacrypt_csrf", "csrf_token")
}
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
c := newCSRFProtect()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := newTestCSRF(t)
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
@@ -50,8 +34,8 @@ func TestCSRFMiddlewareAllowsGET(t *testing.T) {
}
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
c := newCSRFProtect()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := newTestCSRF(t)
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
@@ -64,17 +48,20 @@ func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
}
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
c := newCSRFProtect()
token, _ := c.generateToken()
c := newTestCSRF(t)
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Generate a token via SetToken (writes cookie to response).
rec := httptest.NewRecorder()
token := c.SetToken(rec)
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
body := csrfFieldName + "=" + token
body := "csrf_token=" + url.QueryEscape(token)
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token})
req.AddCookie(&http.Cookie{Name: "metacrypt_csrf", Value: token})
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
@@ -82,24 +69,3 @@ func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
}
}
func TestCSRFMiddlewareRejectsMismatch(t *testing.T) {
c := newCSRFProtect()
token1, _ := c.generateToken()
token2, _ := c.generateToken()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
body := csrfFieldName + "=" + token1
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token2})
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Fatalf("POST with mismatched tokens should be forbidden, got %d", w.Code)
}
}