Files
mcdsl/csrf/csrf.go
Kyle Isom 27f81c81ac Add csrf package: HMAC-SHA256 double-submit cookies
- Protect with configurable secret, cookie name, field name
- Middleware validates POST/PUT/PATCH/DELETE, passes GET/HEAD/OPTIONS
- SetToken generates token and sets HttpOnly/Secure/SameSite=Strict cookie
- TemplateFunc returns FuncMap with csrfField helper for templates
- Token format: base64(nonce).base64(HMAC-SHA256(secret, nonce))
- 10 tests

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 16:29:42 -07:00

145 lines
4.0 KiB
Go

// Package csrf provides HMAC-SHA256 double-submit cookie CSRF protection
// for Metacircular web UIs.
//
// The token format is base64(nonce) + "." + base64(HMAC-SHA256(secret, nonce)).
// A fresh token is set as a cookie each time SetToken (or TemplateFunc,
// which calls it) is used while rendering a page. Mutating requests
// (POST, PUT, PATCH, DELETE) must include the token as a form field that
// matches the cookie value. Both the match and the HMAC signature are
// verified.
package csrf
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"html/template"
"net/http"
"strings"
)
// Protect holds everything needed to issue and check CSRF tokens: the
// HMAC key plus the names of the cookie and form field that carry the
// token. Construct one with New.
type Protect struct {
	secret     [32]byte // HMAC-SHA256 key used to sign and verify token nonces
	cookieName string   // name of the cookie that stores the token
	fieldName  string   // name of the form field that must echo the token
}
// New returns a Protect that signs tokens with secret and uses
// cookieName and fieldName as the names of the CSRF cookie and form
// field.
//
// secret should be exactly 32 bytes read from crypto/rand and unique to
// each service instance. It is copied into a fixed 32-byte key: bytes
// beyond the 32nd are ignored, and a shorter secret is zero-padded, so a
// nil or empty secret silently yields an all-zero key. Later changes to
// the caller's slice do not affect the returned Protect.
//
// Typical usage:
//
//	secret := make([]byte, 32)
//	if _, err := rand.Read(secret); err != nil { // crypto/rand
//		log.Fatal(err)
//	}
//	protect := csrf.New(secret, "myservice_csrf", "csrf_token")
func New(secret []byte, cookieName, fieldName string) *Protect {
	var key [32]byte
	copy(key[:], secret)
	return &Protect{
		secret:     key,
		cookieName: cookieName,
		fieldName:  fieldName,
	}
}
// Middleware returns a handler that enforces CSRF validation before
// calling next.
//
// GET, HEAD, and OPTIONS requests pass straight through. Any other
// method is rejected with 403 Forbidden unless all of the following
// hold:
//
//   - the CSRF cookie is present and non-empty;
//   - the token read with r.FormValue from the configured field is
//     non-empty (r.FormValue also consults the URL query string, and
//     net/http only parses POST, PUT, and PATCH bodies, so a DELETE
//     must carry the token in the query string);
//   - the submitted token equals the cookie value; and
//   - the token's HMAC signature verifies against the secret.
func (p *Protect) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.Method {
		case http.MethodGet, http.MethodHead, http.MethodOptions:
			next.ServeHTTP(w, r)
			return
		}

		formToken := r.FormValue(p.fieldName) //nolint:gosec // form size is bounded by the http.Server's MaxBytesReader or ReadTimeout
		cookie, err := r.Cookie(p.cookieName)
		if err != nil || cookie.Value == "" || formToken == "" {
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
		// Compare in constant time. A plain != stops at the first
		// differing byte, so its timing could leak how much of the
		// cookie value a guessed token matches.
		if !hmac.Equal([]byte(formToken), []byte(cookie.Value)) {
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
		if !p.validateToken(formToken) {
			http.Error(w, "forbidden", http.StatusForbidden)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// SetToken issues a new CSRF token for the response. It mints a token,
// stores it in the CSRF cookie, and returns it so the caller can embed
// it in a form. The cookie is scoped to path "/" and marked HttpOnly,
// Secure, and SameSite=Strict.
//
// Each call sets a new cookie value. Middleware requires the submitted
// token to equal the cookie, so it will reject a form that was rendered
// with an earlier token (for example, in another open tab).
func (p *Protect) SetToken(w http.ResponseWriter) string {
	c := &http.Cookie{
		Name:     p.cookieName,
		Path:     "/",
		HttpOnly: true,
		Secure:   true,
		SameSite: http.SameSiteStrictMode,
	}
	c.Value = p.generateToken()
	http.SetCookie(w, c)
	return c.Value
}
// TemplateFunc issues a CSRF token for w and returns a
// [template.FuncMap] with one entry, "csrfField". The token is issued
// through SetToken, which sets the cookie immediately. csrfField takes
// no arguments and renders a hidden <input> whose name is the
// configured field name and whose value is the issued token, both
// HTML-escaped.
//
// Register the map before rendering:
//
//	tmpl.Funcs(protect.TemplateFunc(w))
//
// and use it inside forms:
//
//	<form method="POST">
//	{{ csrfField }}
//	...
//	</form>
//
// The token and cookie are issued when TemplateFunc is called, whether
// or not the template ever invokes csrfField.
//
// NOTE(review): html/template requires csrfField to be defined before
// Parse. The returned function is bound to this response's token, so
// calling Funcs on one template shared by concurrent requests could
// render another request's token. Presumably callers should Funcs a
// per-request Clone; confirm against call sites.
func (p *Protect) TemplateFunc(w http.ResponseWriter) template.FuncMap {
	const hiddenInput = `<input type="hidden" name="%s" value="%s">`
	token := p.SetToken(w)
	csrfField := func() template.HTML {
		name := template.HTMLEscapeString(p.fieldName)
		value := template.HTMLEscapeString(token)
		return template.HTML(fmt.Sprintf(hiddenInput, name, value)) //nolint:gosec // output is escaped field name + validated token
	}
	return template.FuncMap{"csrfField": csrfField}
}
// generateToken mints a new token of the form
// base64(nonce) + "." + base64(HMAC-SHA256(p.secret, nonce)). It uses
// standard padded base64 and a 32-byte nonce from crypto/rand. If the
// random source fails, it panics rather than issue a token without a
// random nonce.
func (p *Protect) generateToken() string {
	var nonce [32]byte
	if _, err := rand.Read(nonce[:]); err != nil {
		panic("csrf: failed to read random bytes: " + err.Error())
	}
	h := hmac.New(sha256.New, p.secret[:])
	h.Write(nonce[:]) // hash.Hash.Write never returns an error
	enc := base64.StdEncoding
	return enc.EncodeToString(nonce[:]) + "." + enc.EncodeToString(h.Sum(nil))
}
// validateToken reports whether token has the form
// base64(nonce) + "." + base64(sig), where sig equals
// HMAC-SHA256(p.secret, nonce). It splits the token at the first ".".
// A token with no separator, or with invalid standard base64 in either
// half, is reported as invalid. The signature is compared with
// hmac.Equal, which does not leak timing information.
func (p *Protect) validateToken(token string) bool {
	encNonce, encSig, found := strings.Cut(token, ".")
	if !found {
		return false
	}
	nonce, nonceErr := base64.StdEncoding.DecodeString(encNonce)
	sig, sigErr := base64.StdEncoding.DecodeString(encSig)
	if nonceErr != nil || sigErr != nil {
		return false
	}
	expected := hmac.New(sha256.New, p.secret[:])
	expected.Write(nonce)
	return hmac.Equal(sig, expected.Sum(nil))
}