CSRF: Replace local csrfProtect with mcdsl/csrf.Protect. Delete internal/webserver/csrf.go. Web: Replace renderTemplate with web.RenderTemplate + csrf.TemplateFunc. Replace extractCookie with web.GetSessionToken. Replace manual session cookie SetCookie with web.SetSessionCookie. Snapshot: Replace local sqliteBackup with mcdsl/db.Snapshot. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
72 lines
2.0 KiB
Go
72 lines
2.0 KiB
Go
package webserver
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
|
|
"git.wntrmute.dev/kyle/mcdsl/csrf"
|
|
)
|
|
|
|
func newTestCSRF(t *testing.T) *csrf.Protect {
|
|
t.Helper()
|
|
secret := make([]byte, 32)
|
|
if _, err := rand.Read(secret); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return csrf.New(secret, "metacrypt_csrf", "csrf_token")
|
|
}
|
|
|
|
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
|
|
c := newTestCSRF(t)
|
|
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("GET should pass through, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
|
|
c := newTestCSRF(t)
|
|
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
if w.Code != http.StatusForbidden {
|
|
t.Fatalf("POST without CSRF token should be forbidden, got %d", w.Code)
|
|
}
|
|
}
|
|
|
|
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
|
|
c := newTestCSRF(t)
|
|
|
|
// Generate a token via SetToken (writes cookie to response).
|
|
rec := httptest.NewRecorder()
|
|
token := c.SetToken(rec)
|
|
|
|
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
body := "csrf_token=" + url.QueryEscape(token)
|
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
req.AddCookie(&http.Cookie{Name: "metacrypt_csrf", Value: token})
|
|
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
if w.Code != http.StatusOK {
|
|
t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
|
|
}
|
|
}
|