package webserver

import (
	"crypto/rand"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"

	"git.wntrmute.dev/kyle/mcdsl/csrf"
)

// Cookie and form-field names the test CSRF protector is configured with.
// Shared as constants so newTestCSRF and the request builders cannot drift.
const (
	testCSRFCookie = "metacrypt_csrf"
	testCSRFField  = "csrf_token"
)

// newTestCSRF returns a csrf.Protect keyed with a fresh random 32-byte
// secret, configured with testCSRFCookie and testCSRFField.
func newTestCSRF(t *testing.T) *csrf.Protect {
	t.Helper()
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		t.Fatal(err)
	}
	return csrf.New(secret, testCSRFCookie, testCSRFField)
}

// okHandler returns a handler that responds 200 OK and sets *called to true.
// Tests use the flag to tell whether the middleware actually forwarded the
// request, rather than relying on the status code alone.
func okHandler(called *bool) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		*called = true
		w.WriteHeader(http.StatusOK)
	})
}

// newFormPOST builds a POST request to "/" carrying body as an
// application/x-www-form-urlencoded form.
func newFormPOST(body string) *http.Request {
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	return req
}

func TestCSRFMiddlewareAllowsGET(t *testing.T) {
	var called bool
	handler := newTestCSRF(t).Middleware(okHandler(&called))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("GET should pass through, got %d", w.Code)
	}
	if !called {
		t.Fatal("GET should reach the wrapped handler")
	}
}

func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
	var called bool
	handler := newTestCSRF(t).Middleware(okHandler(&called))

	req := newFormPOST("foo=bar")
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusForbidden {
		t.Fatalf("POST without CSRF token should be forbidden, got %d", w.Code)
	}
	// A 403 alone is not enough: the protected handler must never run.
	if called {
		t.Fatal("POST without CSRF token must not reach the wrapped handler")
	}
}

func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
	c := newTestCSRF(t)

	// SetToken returns the token and writes the CSRF cookie to the response.
	// Replay the cookie it actually wrote instead of assuming its value is
	// identical to the returned token.
	rec := httptest.NewRecorder()
	token := c.SetToken(rec)
	res := rec.Result()
	defer res.Body.Close()

	var csrfCookie *http.Cookie
	for _, ck := range res.Cookies() {
		if ck.Name == testCSRFCookie {
			csrfCookie = ck
			break
		}
	}
	if csrfCookie == nil {
		t.Fatalf("SetToken did not set the %q cookie", testCSRFCookie)
	}

	var called bool
	handler := c.Middleware(okHandler(&called))

	req := newFormPOST(testCSRFField + "=" + url.QueryEscape(token))
	req.AddCookie(csrfCookie)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
	}
	if !called {
		t.Fatal("POST with valid CSRF token should reach the wrapped handler")
	}
}