From 27f81c81acaae730fe98b57ae9d1519af96f63f9 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Wed, 25 Mar 2026 16:29:42 -0700 Subject: [PATCH] Add csrf package: HMAC-SHA256 double-submit cookies - Protect with configurable secret, cookie name, field name - Middleware validates POST/PUT/PATCH/DELETE, passes GET/HEAD/OPTIONS - SetToken generates token and sets HttpOnly/Secure/SameSite=Strict cookie - TemplateFunc returns FuncMap with csrfField helper for templates - Token format: base64(nonce).base64(HMAC-SHA256(secret, nonce)) - 10 tests Co-Authored-By: Claude Opus 4.6 (1M context) --- csrf/csrf.go | 144 ++++++++++++++++++++++++++++++++ csrf/csrf_test.go | 207 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 351 insertions(+) create mode 100644 csrf/csrf.go create mode 100644 csrf/csrf_test.go diff --git a/csrf/csrf.go b/csrf/csrf.go new file mode 100644 index 0000000..c4b8d44 --- /dev/null +++ b/csrf/csrf.go @@ -0,0 +1,144 @@ +// Package csrf provides HMAC-SHA256 double-submit cookie CSRF protection +// for Metacircular web UIs. +// +// The token format is base64(nonce) + "." + base64(HMAC-SHA256(secret, nonce)). +// A fresh token is set as a cookie on each page load. Mutating requests +// (POST, PUT, PATCH, DELETE) must include the token as a form field that +// matches the cookie value. Both the match and the HMAC signature are +// verified. +package csrf + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "fmt" + "html/template" + "net/http" + "strings" +) + +// Protect provides CSRF token generation, validation, and middleware. +type Protect struct { + secret [32]byte + cookieName string + fieldName string +} + +// New creates a Protect with the given secret, cookie name, and form +// field name. The secret must be 32 bytes from crypto/rand and should +// be unique per service instance. +// +// Typical usage: +// +// secret := make([]byte, 32) +// crypto_rand.Read(secret) +// csrf := csrf.New(secret, "myservice_csrf", "csrf_token") +func New(secret []byte, cookieName, fieldName string) *Protect { + p := &Protect{ + cookieName: cookieName, + fieldName: fieldName, + } + copy(p.secret[:], secret) + return p +} + +// Middleware validates CSRF tokens on mutating requests (POST, PUT, +// PATCH, DELETE). Safe methods (GET, HEAD, OPTIONS) pass through. +// Returns 403 Forbidden if the token is missing, mismatched, or invalid. +func (p *Protect) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet, http.MethodHead, http.MethodOptions: + next.ServeHTTP(w, r) + return + } + + formToken := r.FormValue(p.fieldName) //nolint:gosec // form size is bounded by the http.Server's MaxBytesReader or ReadTimeout + cookie, err := r.Cookie(p.cookieName) + if err != nil || cookie.Value == "" || formToken == "" { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if formToken != cookie.Value { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + if !p.validateToken(formToken) { + http.Error(w, "forbidden", http.StatusForbidden) + return + } + next.ServeHTTP(w, r) + }) +} + +// SetToken generates a new CSRF token, sets it as a cookie on the +// response, and returns the token string. Call this when rendering +// pages that contain forms. +func (p *Protect) SetToken(w http.ResponseWriter) string { + token := p.generateToken() + http.SetCookie(w, &http.Cookie{ + Name: p.cookieName, + Value: token, + Path: "/", + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteStrictMode, + }) + return token +} + +// TemplateFunc returns a [template.FuncMap] containing a "csrfField" +// function that renders a hidden input with the CSRF token. It calls +// SetToken to set the cookie. Use in template rendering: +// +// tmpl.Funcs(csrf.TemplateFunc(w)) +// +// In templates: +// +//
+// {{ csrfField }} +// ... +//
+func (p *Protect) TemplateFunc(w http.ResponseWriter) template.FuncMap { + token := p.SetToken(w) + return template.FuncMap{ + "csrfField": func() template.HTML { + return template.HTML(fmt.Sprintf( //nolint:gosec // output is escaped field name + validated token + ``, + template.HTMLEscapeString(p.fieldName), + template.HTMLEscapeString(token), + )) + }, + } +} + +func (p *Protect) generateToken() string { + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + panic("csrf: failed to read random bytes: " + err.Error()) + } + mac := hmac.New(sha256.New, p.secret[:]) + mac.Write(nonce) + sig := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(nonce) + "." + base64.StdEncoding.EncodeToString(sig) +} + +func (p *Protect) validateToken(token string) bool { + parts := strings.SplitN(token, ".", 2) + if len(parts) != 2 { + return false + } + nonce, err := base64.StdEncoding.DecodeString(parts[0]) + if err != nil { + return false + } + sig, err := base64.StdEncoding.DecodeString(parts[1]) + if err != nil { + return false + } + mac := hmac.New(sha256.New, p.secret[:]) + mac.Write(nonce) + return hmac.Equal(sig, mac.Sum(nil)) +} diff --git a/csrf/csrf_test.go b/csrf/csrf_test.go new file mode 100644 index 0000000..1973966 --- /dev/null +++ b/csrf/csrf_test.go @@ -0,0 +1,207 @@ +package csrf + +import ( + "crypto/rand" + "html/template" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func testProtect(t *testing.T) *Protect { + t.Helper() + secret := make([]byte, 32) + if _, err := rand.Read(secret); err != nil { + t.Fatalf("generate secret: %v", err) + } + return New(secret, "test_csrf", "csrf_token") +} + +func TestGenerateAndValidate(t *testing.T) { + p := testProtect(t) + token := p.generateToken() + + if token == "" { + t.Fatal("empty token") + } + if !strings.Contains(token, ".") { + t.Fatal("token missing separator") + } + if !p.validateToken(token) { + t.Fatal("valid token rejected") + } +} + +func TestValidateInvalid(t *testing.T) { + p := testProtect(t) + + tests := []struct { + name string + token string + }{ + {"empty", ""}, + {"no separator", "abcdef"}, + {"bad nonce", "!!!." + "AAAA"}, + {"bad sig", "AAAA" + ".!!!"}, + {"wrong sig", "AAAA.BBBB"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if p.validateToken(tt.token) { + t.Fatalf("invalid token %q accepted", tt.token) + } + }) + } +} + +func TestValidateWrongSecret(t *testing.T) { + p1 := testProtect(t) + p2 := testProtect(t) + + token := p1.generateToken() + if p2.validateToken(token) { + t.Fatal("token validated with wrong secret") + } +} + +func TestTokenUniqueness(t *testing.T) { + p := testProtect(t) + t1 := p.generateToken() + t2 := p.generateToken() + if t1 == t2 { + t.Fatal("two tokens are identical") + } +} + +func TestSetToken(t *testing.T) { + p := testProtect(t) + rec := httptest.NewRecorder() + + token := p.SetToken(rec) + + if token == "" { + t.Fatal("empty token") + } + + cookies := rec.Result().Cookies() + var found bool + for _, c := range cookies { + if c.Name == "test_csrf" { + found = true + if c.Value != token { + t.Fatalf("cookie value = %q, want %q", c.Value, token) + } + if !c.HttpOnly { + t.Fatal("cookie not HttpOnly") + } + if !c.Secure { + t.Fatal("cookie not Secure") + } + if c.SameSite != http.SameSiteStrictMode { + t.Fatal("cookie not SameSite=Strict") + } + } + } + if !found { + t.Fatal("CSRF cookie not set") + } +} + +func TestMiddlewareSafeMethods(t *testing.T) { + p := testProtect(t) + called := false + handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} { + called = false + rec := httptest.NewRecorder() + req := httptest.NewRequest(method, "/", nil) + handler.ServeHTTP(rec, req) + if !called { + t.Fatalf("%s: handler not called", method) + } + } +} + +func TestMiddlewareRejectsMissingToken(t *testing.T) { + p := testProtect(t) + handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("handler should not be called") + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", nil) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } +} + +func TestMiddlewareRejectsMismatch(t *testing.T) { + p := testProtect(t) + handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + t.Fatal("handler should not be called") + })) + + token := p.generateToken() + form := url.Values{"csrf_token": {"wrong-token"}} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: "test_csrf", Value: token}) + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } +} + +func TestMiddlewareAcceptsValid(t *testing.T) { + p := testProtect(t) + called := false + handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true + })) + + token := p.generateToken() + form := url.Values{"csrf_token": {token}} + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.AddCookie(&http.Cookie{Name: "test_csrf", Value: token}) + handler.ServeHTTP(rec, req) + + if !called { + t.Fatal("handler not called for valid CSRF") + } +} + +func TestTemplateFunc(t *testing.T) { + p := testProtect(t) + rec := httptest.NewRecorder() + + funcs := p.TemplateFunc(rec) + csrfField, ok := funcs["csrfField"] + if !ok { + t.Fatal("csrfField not in FuncMap") + } + + fn, ok := csrfField.(func() template.HTML) + if !ok { + t.Fatal("csrfField is not func() template.HTML") + } + + html := string(fn()) + if !strings.Contains(html, `name="csrf_token"`) { + t.Fatalf("csrfField HTML missing field name: %s", html) + } + if !strings.Contains(html, `type="hidden"`) { + t.Fatalf("csrfField HTML missing type=hidden: %s", html) + } +}