package csrf

import (
	"crypto/rand"
	"html/template"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
)

// testProtect returns a Protect built with a fresh random 32-byte secret,
// the cookie name "test_csrf" and the form field name "csrf_token".
// Each call yields an independent secret, so tokens from one instance
// should not validate on another.
func testProtect(t *testing.T) *Protect {
	t.Helper()
	secret := make([]byte, 32)
	if _, err := rand.Read(secret); err != nil {
		t.Fatalf("generate secret: %v", err)
	}
	return New(secret, "test_csrf", "csrf_token")
}

// newCSRFPost builds a form-encoded POST request for the middleware tests.
// formToken, when non-empty, is sent as the "csrf_token" form field;
// cookieToken, when non-empty, is sent as the "test_csrf" cookie. Passing an
// empty string omits that half of the pair entirely.
func newCSRFPost(formToken, cookieToken string) *http.Request {
	form := url.Values{}
	if formToken != "" {
		form.Set("csrf_token", formToken)
	}
	req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	if cookieToken != "" {
		req.AddCookie(&http.Cookie{Name: "test_csrf", Value: cookieToken})
	}
	return req
}

// TestGenerateAndValidate checks that a freshly generated token is non-empty,
// contains the "." separator and is accepted by the same Protect.
func TestGenerateAndValidate(t *testing.T) {
	p := testProtect(t)
	token := p.generateToken()
	if token == "" {
		t.Fatal("empty token")
	}
	if !strings.Contains(token, ".") {
		t.Fatal("token missing separator")
	}
	if !p.validateToken(token) {
		t.Fatal("valid token rejected")
	}
}

// TestValidateInvalid checks that malformed and forged tokens are rejected.
func TestValidateInvalid(t *testing.T) {
	p := testProtect(t)

	type testCase struct {
		name  string
		token string
	}
	tests := []testCase{
		{"empty", ""},
		{"no separator", "abcdef"},
		{"bad nonce", "!!!." + "AAAA"},
		{"bad sig", "AAAA" + ".!!!"},
		{"wrong sig", "AAAA.BBBB"},
	}

	// Splice the first half of one genuine token onto the second half of
	// another. Each half is individually well-formed, so only a signature
	// that binds the two halves together can reject it. This assumes the
	// layout is "nonce.signature", as the "bad nonce"/"bad sig" cases imply.
	a, b := p.generateToken(), p.generateToken()
	ia, ib := strings.Index(a, "."), strings.Index(b, ".")
	if ia < 0 || ib < 0 {
		t.Fatalf("token missing separator: %q, %q", a, b)
	}
	tests = append(tests, testCase{"spliced halves", a[:ia] + "." + b[ib+1:]})

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if p.validateToken(tt.token) {
				t.Fatalf("invalid token %q accepted", tt.token)
			}
		})
	}
}

// TestValidateWrongSecret checks that a token signed under one secret is not
// accepted by a Protect holding a different secret.
func TestValidateWrongSecret(t *testing.T) {
	p1 := testProtect(t)
	p2 := testProtect(t)
	token := p1.generateToken()
	if p2.validateToken(token) {
		t.Fatal("token validated with wrong secret")
	}
}

// TestTokenUniqueness checks that two consecutive tokens differ.
func TestTokenUniqueness(t *testing.T) {
	p := testProtect(t)
	t1 := p.generateToken()
	t2 := p.generateToken()
	if t1 == t2 {
		t.Fatal("two tokens are identical")
	}
}

// TestSetToken checks that SetToken returns a token and writes it to a
// "test_csrf" cookie marked HttpOnly, Secure and SameSite=Strict.
func TestSetToken(t *testing.T) {
	p := testProtect(t)
	rec := httptest.NewRecorder()
	token := p.SetToken(rec)
	if token == "" {
		t.Fatal("empty token")
	}

	res := rec.Result()
	defer res.Body.Close()

	var found bool
	for _, c := range res.Cookies() {
		if c.Name != "test_csrf" {
			continue
		}
		found = true
		if c.Value != token {
			t.Fatalf("cookie value = %q, want %q", c.Value, token)
		}
		if !c.HttpOnly {
			t.Fatal("cookie not HttpOnly")
		}
		if !c.Secure {
			t.Fatal("cookie not Secure")
		}
		if c.SameSite != http.SameSiteStrictMode {
			t.Fatal("cookie not SameSite=Strict")
		}
	}
	if !found {
		t.Fatal("CSRF cookie not set")
	}
}

// TestMiddlewareSafeMethods checks that GET, HEAD and OPTIONS reach the
// wrapped handler without any CSRF token.
func TestMiddlewareSafeMethods(t *testing.T) {
	p := testProtect(t)
	for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
		t.Run(method, func(t *testing.T) {
			called := false
			handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
				called = true
			}))
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, httptest.NewRequest(method, "/", nil))
			if !called {
				t.Fatalf("%s: handler not called", method)
			}
		})
	}
}

// TestMiddlewareRejectsMissingToken checks that unsafe requests get 403 and
// never reach the handler when the cookie, the form field, or both are absent.
func TestMiddlewareRejectsMissingToken(t *testing.T) {
	p := testProtect(t)
	token := p.generateToken()

	tests := []struct {
		name string
		req  *http.Request
	}{
		{"POST without cookie or field", httptest.NewRequest(http.MethodPost, "/", nil)},
		{"PUT without cookie or field", httptest.NewRequest(http.MethodPut, "/", nil)},
		{"PATCH without cookie or field", httptest.NewRequest(http.MethodPatch, "/", nil)},
		{"DELETE without cookie or field", httptest.NewRequest(http.MethodDelete, "/", nil)},
		{"cookie without field", newCSRFPost("", token)},
		{"field without cookie", newCSRFPost(token, "")},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build the handler with the subtest's t: FailNow on a parent
			// test must not be called from a subtest's goroutine.
			handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
				t.Fatal("handler should not be called")
			}))
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, tt.req)
			if rec.Code != http.StatusForbidden {
				t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
			}
		})
	}
}

// TestMiddlewareRejectsMismatch checks that a POST is refused with 403 when
// the cookie and form-field tokens do not form a valid matching pair.
func TestMiddlewareRejectsMismatch(t *testing.T) {
	p := testProtect(t)
	cookieToken := p.generateToken()
	otherToken := p.generateToken()
	foreignToken := testProtect(t).generateToken()

	tests := []struct {
		name   string
		form   string
		cookie string
	}{
		// Rejected by token validation before any comparison happens.
		{"malformed field", "wrong-token", cookieToken},
		// Both tokens are validly signed, so only the cookie/field
		// comparison can reject this one.
		{"different valid token", otherToken, cookieToken},
		// Matching pair, but signed with a secret this Protect does not hold.
		{"pair signed with other secret", foreignToken, foreignToken},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
				t.Fatal("handler should not be called")
			}))
			rec := httptest.NewRecorder()
			handler.ServeHTTP(rec, newCSRFPost(tt.form, tt.cookie))
			if rec.Code != http.StatusForbidden {
				t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
			}
		})
	}
}

// TestMiddlewareAcceptsValid checks that a POST carrying the same valid token
// in both the cookie and the form field reaches the wrapped handler.
func TestMiddlewareAcceptsValid(t *testing.T) {
	p := testProtect(t)
	called := false
	handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
		called = true
	}))
	token := p.generateToken()
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, newCSRFPost(token, token))
	if !called {
		t.Fatal("handler not called for valid CSRF")
	}
}

// TestTemplateFunc checks that TemplateFunc exposes a "csrfField" entry of type
// func() template.HTML whose output names the "csrf_token" field and is a
// hidden input.
func TestTemplateFunc(t *testing.T) {
	p := testProtect(t)
	rec := httptest.NewRecorder()
	funcs := p.TemplateFunc(rec)
	csrfField, ok := funcs["csrfField"]
	if !ok {
		t.Fatal("csrfField not in FuncMap")
	}
	fn, ok := csrfField.(func() template.HTML)
	if !ok {
		t.Fatal("csrfField is not func() template.HTML")
	}
	html := string(fn())
	if !strings.Contains(html, `name="csrf_token"`) {
		t.Fatalf("csrfField HTML missing field name: %s", html)
	}
	if !strings.Contains(html, `type="hidden"`) {
		t.Fatalf("csrfField HTML missing type=hidden: %s", html)
	}
}