package webserver

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
)

// mustCSRFToken calls generate and fails the test immediately if it returns an
// error. Discarding the error would leave the token empty, and several
// assertions below (which expect a rejection) would then pass vacuously.
func mustCSRFToken(t *testing.T, generate func() (string, error)) string {
	t.Helper()
	token, err := generate()
	if err != nil {
		t.Fatalf("generateToken: %v", err)
	}
	return token
}

// csrfNextHandler returns a handler that replies 200 OK, plus a flag that is
// set when the handler runs. Checking the flag matters because
// httptest.ResponseRecorder reports Code 200 by default, so the status code
// alone cannot prove the middleware forwarded the request.
func csrfNextHandler() (http.HandlerFunc, *bool) {
	called := new(bool)
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		*called = true
		w.WriteHeader(http.StatusOK)
	})
	return next, called
}

// newCSRFFormRequest builds a POST to "/" whose body is the URL-encoded form.
// Encoding via url.Values keeps the test correct whatever characters a token
// contains.
func newCSRFFormRequest(form url.Values) *http.Request {
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	return req
}

// TestCSRFTokenGenerateAndValidate checks that a freshly generated token is
// accepted by the same protector.
func TestCSRFTokenGenerateAndValidate(t *testing.T) {
	c := newCSRFProtect()
	token := mustCSRFToken(t, c.generateToken)
	if !c.validToken(token) {
		t.Fatal("valid token rejected")
	}
}

// TestCSRFTokenInvalidFormats checks that malformed token strings are rejected.
func TestCSRFTokenInvalidFormats(t *testing.T) {
	c := newCSRFProtect()
	for _, bad := range []string{"", "nodot", "a.b.c", "abc.def"} {
		if c.validToken(bad) {
			t.Errorf("should reject %q", bad)
		}
	}
}

// TestCSRFTokenCrossSecret checks that a token generated by one protector is
// rejected by a separately constructed one.
func TestCSRFTokenCrossSecret(t *testing.T) {
	c1 := newCSRFProtect()
	c2 := newCSRFProtect()
	token := mustCSRFToken(t, c1.generateToken)
	if c2.validToken(token) {
		t.Fatal("token from different secret should be rejected")
	}
}

// TestCSRFMiddlewareAllowsGET checks that a GET is forwarded to the wrapped
// handler without any token.
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
	c := newCSRFProtect()
	next, called := csrfNextHandler()
	handler := c.middleware(next)

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("GET should pass through, got %d", w.Code)
	}
	if !*called {
		t.Fatal("GET should reach the wrapped handler")
	}
}

// TestCSRFMiddlewareBlocksPOSTWithoutToken checks that a form POST carrying no
// CSRF token is answered with 403 and never reaches the wrapped handler.
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
	c := newCSRFProtect()
	next, called := csrfNextHandler()
	handler := c.middleware(next)

	req := newCSRFFormRequest(url.Values{"foo": {"bar"}})
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusForbidden {
		t.Fatalf("POST without CSRF token should be forbidden, got %d", w.Code)
	}
	if *called {
		t.Fatal("POST without CSRF token must not reach the wrapped handler")
	}
}

// TestCSRFMiddlewareAllowsPOSTWithValidToken checks that a POST carrying the
// same valid token in the form field and in the cookie is forwarded.
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
	c := newCSRFProtect()
	token := mustCSRFToken(t, c.generateToken)
	next, called := csrfNextHandler()
	handler := c.middleware(next)

	req := newCSRFFormRequest(url.Values{csrfFieldName: {token}})
	req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token})
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
	}
	if !*called {
		t.Fatal("POST with valid CSRF token should reach the wrapped handler")
	}
}

// TestCSRFMiddlewareRejectsMismatch checks that a POST whose form token and
// cookie token differ (both generated by the same protector) is rejected.
func TestCSRFMiddlewareRejectsMismatch(t *testing.T) {
	c := newCSRFProtect()
	token1 := mustCSRFToken(t, c.generateToken)
	token2 := mustCSRFToken(t, c.generateToken)
	// The test only exercises a mismatch if the two tokens actually differ.
	if token1 == token2 {
		t.Fatalf("generateToken returned the same token twice (%q); cannot test a mismatch", token1)
	}
	next, called := csrfNextHandler()
	handler := c.middleware(next)

	req := newCSRFFormRequest(url.Values{csrfFieldName: {token1}})
	req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token2})
	w := httptest.NewRecorder()
	handler.ServeHTTP(w, req)

	if w.Code != http.StatusForbidden {
		t.Fatalf("POST with mismatched tokens should be forbidden, got %d", w.Code)
	}
	if *called {
		t.Fatal("POST with mismatched tokens must not reach the wrapped handler")
	}
}