Add csrf package: HMAC-SHA256 double-submit cookies
- Protect with configurable secret, cookie name, field name - Middleware validates POST/PUT/PATCH/DELETE, passes GET/HEAD/OPTIONS - SetToken generates token and sets HttpOnly/Secure/SameSite=Strict cookie - TemplateFunc returns FuncMap with csrfField helper for templates - Token format: base64(nonce).base64(HMAC-SHA256(secret, nonce)) - 10 tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
207
csrf/csrf_test.go
Normal file
207
csrf/csrf_test.go
Normal file
@@ -0,0 +1,207 @@
|
||||
package csrf
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"html/template"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func testProtect(t *testing.T) *Protect {
|
||||
t.Helper()
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
t.Fatalf("generate secret: %v", err)
|
||||
}
|
||||
return New(secret, "test_csrf", "csrf_token")
|
||||
}
|
||||
|
||||
func TestGenerateAndValidate(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
token := p.generateToken()
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("empty token")
|
||||
}
|
||||
if !strings.Contains(token, ".") {
|
||||
t.Fatal("token missing separator")
|
||||
}
|
||||
if !p.validateToken(token) {
|
||||
t.Fatal("valid token rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateInvalid(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
}{
|
||||
{"empty", ""},
|
||||
{"no separator", "abcdef"},
|
||||
{"bad nonce", "!!!." + "AAAA"},
|
||||
{"bad sig", "AAAA" + ".!!!"},
|
||||
{"wrong sig", "AAAA.BBBB"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if p.validateToken(tt.token) {
|
||||
t.Fatalf("invalid token %q accepted", tt.token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateWrongSecret(t *testing.T) {
|
||||
p1 := testProtect(t)
|
||||
p2 := testProtect(t)
|
||||
|
||||
token := p1.generateToken()
|
||||
if p2.validateToken(token) {
|
||||
t.Fatal("token validated with wrong secret")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenUniqueness(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
t1 := p.generateToken()
|
||||
t2 := p.generateToken()
|
||||
if t1 == t2 {
|
||||
t.Fatal("two tokens are identical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetToken(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
token := p.SetToken(rec)
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("empty token")
|
||||
}
|
||||
|
||||
cookies := rec.Result().Cookies()
|
||||
var found bool
|
||||
for _, c := range cookies {
|
||||
if c.Name == "test_csrf" {
|
||||
found = true
|
||||
if c.Value != token {
|
||||
t.Fatalf("cookie value = %q, want %q", c.Value, token)
|
||||
}
|
||||
if !c.HttpOnly {
|
||||
t.Fatal("cookie not HttpOnly")
|
||||
}
|
||||
if !c.Secure {
|
||||
t.Fatal("cookie not Secure")
|
||||
}
|
||||
if c.SameSite != http.SameSiteStrictMode {
|
||||
t.Fatal("cookie not SameSite=Strict")
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("CSRF cookie not set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareSafeMethods(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
called := false
|
||||
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
|
||||
called = false
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(method, "/", nil)
|
||||
handler.ServeHTTP(rec, req)
|
||||
if !called {
|
||||
t.Fatalf("%s: handler not called", method)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareRejectsMissingToken(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
t.Fatal("handler should not be called")
|
||||
}))
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareRejectsMismatch(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
t.Fatal("handler should not be called")
|
||||
}))
|
||||
|
||||
token := p.generateToken()
|
||||
form := url.Values{"csrf_token": {"wrong-token"}}
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(&http.Cookie{Name: "test_csrf", Value: token})
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusForbidden {
|
||||
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareAcceptsValid(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
called := false
|
||||
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||
called = true
|
||||
}))
|
||||
|
||||
token := p.generateToken()
|
||||
form := url.Values{"csrf_token": {token}}
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(&http.Cookie{Name: "test_csrf", Value: token})
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if !called {
|
||||
t.Fatal("handler not called for valid CSRF")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateFunc(t *testing.T) {
|
||||
p := testProtect(t)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
funcs := p.TemplateFunc(rec)
|
||||
csrfField, ok := funcs["csrfField"]
|
||||
if !ok {
|
||||
t.Fatal("csrfField not in FuncMap")
|
||||
}
|
||||
|
||||
fn, ok := csrfField.(func() template.HTML)
|
||||
if !ok {
|
||||
t.Fatal("csrfField is not func() template.HTML")
|
||||
}
|
||||
|
||||
html := string(fn())
|
||||
if !strings.Contains(html, `name="csrf_token"`) {
|
||||
t.Fatalf("csrfField HTML missing field name: %s", html)
|
||||
}
|
||||
if !strings.Contains(html, `type="hidden"`) {
|
||||
t.Fatalf("csrfField HTML missing type=hidden: %s", html)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user