Add csrf package: HMAC-SHA256 double-submit cookies
- Protect with configurable secret, cookie name, field name - Middleware validates POST/PUT/PATCH/DELETE, passes GET/HEAD/OPTIONS - SetToken generates token and sets HttpOnly/Secure/SameSite=Strict cookie - TemplateFunc returns FuncMap with csrfField helper for templates - Token format: base64(nonce).base64(HMAC-SHA256(secret, nonce)) - 10 tests Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
144
csrf/csrf.go
Normal file
144
csrf/csrf.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
// Package csrf provides HMAC-SHA256 double-submit cookie CSRF protection
|
||||||
|
// for Metacircular web UIs.
|
||||||
|
//
|
||||||
|
// The token format is base64(nonce) + "." + base64(HMAC-SHA256(secret, nonce)).
|
||||||
|
// A fresh token is set as a cookie on each page load. Mutating requests
|
||||||
|
// (POST, PUT, PATCH, DELETE) must include the token as a form field that
|
||||||
|
// matches the cookie value. Both the match and the HMAC signature are
|
||||||
|
// verified.
|
||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Protect provides CSRF token generation, validation, and middleware.
|
||||||
|
type Protect struct {
|
||||||
|
secret [32]byte
|
||||||
|
cookieName string
|
||||||
|
fieldName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a Protect with the given secret, cookie name, and form
|
||||||
|
// field name. The secret must be 32 bytes from crypto/rand and should
|
||||||
|
// be unique per service instance.
|
||||||
|
//
|
||||||
|
// Typical usage:
|
||||||
|
//
|
||||||
|
// secret := make([]byte, 32)
|
||||||
|
// crypto_rand.Read(secret)
|
||||||
|
// csrf := csrf.New(secret, "myservice_csrf", "csrf_token")
|
||||||
|
func New(secret []byte, cookieName, fieldName string) *Protect {
|
||||||
|
p := &Protect{
|
||||||
|
cookieName: cookieName,
|
||||||
|
fieldName: fieldName,
|
||||||
|
}
|
||||||
|
copy(p.secret[:], secret)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware validates CSRF tokens on mutating requests (POST, PUT,
|
||||||
|
// PATCH, DELETE). Safe methods (GET, HEAD, OPTIONS) pass through.
|
||||||
|
// Returns 403 Forbidden if the token is missing, mismatched, or invalid.
|
||||||
|
func (p *Protect) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
formToken := r.FormValue(p.fieldName) //nolint:gosec // form size is bounded by the http.Server's MaxBytesReader or ReadTimeout
|
||||||
|
cookie, err := r.Cookie(p.cookieName)
|
||||||
|
if err != nil || cookie.Value == "" || formToken == "" {
|
||||||
|
http.Error(w, "forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if formToken != cookie.Value {
|
||||||
|
http.Error(w, "forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !p.validateToken(formToken) {
|
||||||
|
http.Error(w, "forbidden", http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetToken generates a new CSRF token, sets it as a cookie on the
|
||||||
|
// response, and returns the token string. Call this when rendering
|
||||||
|
// pages that contain forms.
|
||||||
|
func (p *Protect) SetToken(w http.ResponseWriter) string {
|
||||||
|
token := p.generateToken()
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: p.cookieName,
|
||||||
|
Value: token,
|
||||||
|
Path: "/",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
})
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
// TemplateFunc returns a [template.FuncMap] containing a "csrfField"
|
||||||
|
// function that renders a hidden input with the CSRF token. It calls
|
||||||
|
// SetToken to set the cookie. Use in template rendering:
|
||||||
|
//
|
||||||
|
// tmpl.Funcs(csrf.TemplateFunc(w))
|
||||||
|
//
|
||||||
|
// In templates:
|
||||||
|
//
|
||||||
|
// <form method="POST">
|
||||||
|
// {{ csrfField }}
|
||||||
|
// ...
|
||||||
|
// </form>
|
||||||
|
func (p *Protect) TemplateFunc(w http.ResponseWriter) template.FuncMap {
|
||||||
|
token := p.SetToken(w)
|
||||||
|
return template.FuncMap{
|
||||||
|
"csrfField": func() template.HTML {
|
||||||
|
return template.HTML(fmt.Sprintf( //nolint:gosec // output is escaped field name + validated token
|
||||||
|
`<input type="hidden" name="%s" value="%s">`,
|
||||||
|
template.HTMLEscapeString(p.fieldName),
|
||||||
|
template.HTMLEscapeString(token),
|
||||||
|
))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Protect) generateToken() string {
|
||||||
|
nonce := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(nonce); err != nil {
|
||||||
|
panic("csrf: failed to read random bytes: " + err.Error())
|
||||||
|
}
|
||||||
|
mac := hmac.New(sha256.New, p.secret[:])
|
||||||
|
mac.Write(nonce)
|
||||||
|
sig := mac.Sum(nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(nonce) + "." + base64.StdEncoding.EncodeToString(sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Protect) validateToken(token string) bool {
|
||||||
|
parts := strings.SplitN(token, ".", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
nonce, err := base64.StdEncoding.DecodeString(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
sig, err := base64.StdEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
mac := hmac.New(sha256.New, p.secret[:])
|
||||||
|
mac.Write(nonce)
|
||||||
|
return hmac.Equal(sig, mac.Sum(nil))
|
||||||
|
}
|
||||||
207
csrf/csrf_test.go
Normal file
207
csrf/csrf_test.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package csrf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testProtect(t *testing.T) *Protect {
|
||||||
|
t.Helper()
|
||||||
|
secret := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(secret); err != nil {
|
||||||
|
t.Fatalf("generate secret: %v", err)
|
||||||
|
}
|
||||||
|
return New(secret, "test_csrf", "csrf_token")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateAndValidate(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
token := p.generateToken()
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("empty token")
|
||||||
|
}
|
||||||
|
if !strings.Contains(token, ".") {
|
||||||
|
t.Fatal("token missing separator")
|
||||||
|
}
|
||||||
|
if !p.validateToken(token) {
|
||||||
|
t.Fatal("valid token rejected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateInvalid(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
}{
|
||||||
|
{"empty", ""},
|
||||||
|
{"no separator", "abcdef"},
|
||||||
|
{"bad nonce", "!!!." + "AAAA"},
|
||||||
|
{"bad sig", "AAAA" + ".!!!"},
|
||||||
|
{"wrong sig", "AAAA.BBBB"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if p.validateToken(tt.token) {
|
||||||
|
t.Fatalf("invalid token %q accepted", tt.token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateWrongSecret(t *testing.T) {
|
||||||
|
p1 := testProtect(t)
|
||||||
|
p2 := testProtect(t)
|
||||||
|
|
||||||
|
token := p1.generateToken()
|
||||||
|
if p2.validateToken(token) {
|
||||||
|
t.Fatal("token validated with wrong secret")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUniqueness(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
t1 := p.generateToken()
|
||||||
|
t2 := p.generateToken()
|
||||||
|
if t1 == t2 {
|
||||||
|
t.Fatal("two tokens are identical")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetToken(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
token := p.SetToken(rec)
|
||||||
|
|
||||||
|
if token == "" {
|
||||||
|
t.Fatal("empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
cookies := rec.Result().Cookies()
|
||||||
|
var found bool
|
||||||
|
for _, c := range cookies {
|
||||||
|
if c.Name == "test_csrf" {
|
||||||
|
found = true
|
||||||
|
if c.Value != token {
|
||||||
|
t.Fatalf("cookie value = %q, want %q", c.Value, token)
|
||||||
|
}
|
||||||
|
if !c.HttpOnly {
|
||||||
|
t.Fatal("cookie not HttpOnly")
|
||||||
|
}
|
||||||
|
if !c.Secure {
|
||||||
|
t.Fatal("cookie not Secure")
|
||||||
|
}
|
||||||
|
if c.SameSite != http.SameSiteStrictMode {
|
||||||
|
t.Fatal("cookie not SameSite=Strict")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Fatal("CSRF cookie not set")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareSafeMethods(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
called := false
|
||||||
|
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
}))
|
||||||
|
|
||||||
|
for _, method := range []string{http.MethodGet, http.MethodHead, http.MethodOptions} {
|
||||||
|
called = false
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(method, "/", nil)
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
if !called {
|
||||||
|
t.Fatalf("%s: handler not called", method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareRejectsMissingToken(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
t.Fatal("handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareRejectsMismatch(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
t.Fatal("handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
token := p.generateToken()
|
||||||
|
form := url.Values{"csrf_token": {"wrong-token"}}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.AddCookie(&http.Cookie{Name: "test_csrf", Value: token})
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if rec.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewareAcceptsValid(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
called := false
|
||||||
|
handler := p.Middleware(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
|
||||||
|
called = true
|
||||||
|
}))
|
||||||
|
|
||||||
|
token := p.generateToken()
|
||||||
|
form := url.Values{"csrf_token": {token}}
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(form.Encode()))
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.AddCookie(&http.Cookie{Name: "test_csrf", Value: token})
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("handler not called for valid CSRF")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTemplateFunc(t *testing.T) {
|
||||||
|
p := testProtect(t)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
|
||||||
|
funcs := p.TemplateFunc(rec)
|
||||||
|
csrfField, ok := funcs["csrfField"]
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("csrfField not in FuncMap")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn, ok := csrfField.(func() template.HTML)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("csrfField is not func() template.HTML")
|
||||||
|
}
|
||||||
|
|
||||||
|
html := string(fn())
|
||||||
|
if !strings.Contains(html, `name="csrf_token"`) {
|
||||||
|
t.Fatalf("csrfField HTML missing field name: %s", html)
|
||||||
|
}
|
||||||
|
if !strings.Contains(html, `type="hidden"`) {
|
||||||
|
t.Fatalf("csrfField HTML missing type=hidden: %s", html)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user