- Trusted proxy config option for proxy-aware IP extraction used by rate limiting and audit logs; validates proxy IP before trusting X-Forwarded-For / X-Real-IP headers - TOTP replay protection via counter-based validation to reject reused codes within the same time step (±30s) - RateLimit middleware updated to extract client IP from proxy headers without IP spoofing risk - New tests for ClientIP proxy logic (spoofed headers, fallback) and extended rate-limit proxy coverage - HTMX error banner script integrated into web UI base - .gitignore updated for mciasdb build artifact Security: resolves CRIT-01 (TOTP replay attack) and DEF-03 (proxy-unaware rate limiting); gRPC TOTP enrollment aligned with REST via StorePendingTOTP Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
440 lines
15 KiB
Go
440 lines
15 KiB
Go
// Package middleware provides HTTP middleware for the MCIAS server.
|
|
//
|
|
// Security design:
|
|
// - RequireAuth extracts the Bearer token from the Authorization header,
|
|
// validates it (alg check, signature, expiry, issuer), and checks revocation
|
|
// against the database before injecting claims into the request context.
|
|
// - RequireRole checks claims from context for the required role.
|
|
// No role implies no access; the check fails closed.
|
|
// - RateLimit implements a per-IP token bucket to limit login brute-force.
|
|
// - RequestLogger logs request metadata but never logs the Authorization
|
|
// header value (which contains credential tokens).
|
|
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
|
"git.wntrmute.dev/kyle/mcias/internal/policy"
|
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
|
)
|
|
|
|
// contextKey is the key type for values this package stores in a
// context.Context. Because the type is unexported, no other package can
// construct a colliding key.
type contextKey int

// claimsKey is the context key under which RequireAuth stores the validated
// *token.Claims for downstream handlers.
const claimsKey contextKey = 0
|
|
|
|
// ClaimsFromContext returns the JWT claims that RequireAuth placed in ctx, or
// nil when the request was not authenticated.
//
// Security: the comma-ok type assertion means a missing value, or a value of
// an unexpected type stored under the same key, yields nil instead of a panic.
// Callers treat nil as "no authentication".
func ClaimsFromContext(ctx context.Context) *token.Claims {
	if claims, ok := ctx.Value(claimsKey).(*token.Claims); ok {
		return claims
	}
	return nil
}
|
|
|
|
// RequestLogger returns middleware that emits one INFO-level "request" record
// per request. The record is written after the handler returns. It contains
// the method, path, response status, duration in milliseconds, remote address
// and user agent.
//
// Security: the Authorization header (which carries bearer credentials) is
// never included in the record.
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			began := time.Now()

			// The wrapper records the status the handler sends. It defaults to
			// 200 for handlers that never call WriteHeader explicitly.
			recorder := &responseWriter{ResponseWriter: w, status: http.StatusOK}
			next.ServeHTTP(recorder, r)
			elapsed := time.Since(began)

			logger.Info("request",
				"method", r.Method,
				"path", r.URL.Path,
				"status", recorder.status,
				"duration_ms", elapsed.Milliseconds(),
				"remote_addr", r.RemoteAddr,
				"user_agent", r.UserAgent(),
				// Security: deliberately no Authorization header here.
			)
		})
	}
}
|
|
|
|
// responseWriter wraps http.ResponseWriter to capture the status code.
|
|
type responseWriter struct {
|
|
http.ResponseWriter
|
|
status int
|
|
}
|
|
|
|
func (rw *responseWriter) WriteHeader(code int) {
|
|
rw.status = code
|
|
rw.ResponseWriter.WriteHeader(code)
|
|
}
|
|
|
|
// RequireAuth returns middleware that validates a Bearer JWT and injects the
|
|
// claims into the request context. Returns 401 on any auth failure.
|
|
//
|
|
// Security: Token validation order:
|
|
// 1. Extract Bearer token from Authorization header.
|
|
// 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer).
|
|
// 3. Check the JTI against the revocation table in the database.
|
|
// 4. Inject validated claims into context for downstream handlers.
|
|
func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
tokenStr, err := extractBearerToken(r)
|
|
if err != nil {
|
|
writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized")
|
|
return
|
|
}
|
|
|
|
claims, err := token.ValidateToken(pubKey, tokenStr, issuer)
|
|
if err != nil {
|
|
// Security: Map all token errors to a generic 401; do not
|
|
// reveal which specific check failed.
|
|
if errors.Is(err, token.ErrExpiredToken) {
|
|
writeError(w, http.StatusUnauthorized, "token expired", "token_expired")
|
|
return
|
|
}
|
|
writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized")
|
|
return
|
|
}
|
|
|
|
// Security: Check revocation table. A token may be cryptographically
|
|
// valid but explicitly revoked (logout, account suspension, etc.).
|
|
rec, err := database.GetTokenRecord(claims.JTI)
|
|
if err != nil {
|
|
if errors.Is(err, db.ErrNotFound) {
|
|
// Token not tracked — could be from a different server instance
|
|
// or pre-dates tracking. Reject to be safe (fail closed).
|
|
writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized")
|
|
return
|
|
}
|
|
writeError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
|
return
|
|
}
|
|
if rec.IsRevoked() {
|
|
writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked")
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), claimsKey, claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequireRole returns middleware that admits only requests whose
// authenticated claims include role. It must be chained after RequireAuth.
//
// Security: both failure modes respond 403 and stop the chain (fail closed):
//   - no claims in the context (RequireAuth missing upstream) → "forbidden";
//   - claims present but lacking role → "insufficient privileges".
func RequireRole(role string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			switch claims := ClaimsFromContext(r.Context()); {
			case claims == nil:
				writeError(w, http.StatusForbidden, "forbidden", "forbidden")
			case !claims.HasRole(role):
				writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden")
			default:
				next.ServeHTTP(w, r)
			}
		})
	}
}
|
|
|
|
// rateLimitEntry holds the token bucket state for a single IP.
|
|
type rateLimitEntry struct {
|
|
lastSeen time.Time
|
|
tokens float64
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// ipRateLimiter implements a per-IP token bucket rate limiter.
|
|
type ipRateLimiter struct {
|
|
ips map[string]*rateLimitEntry
|
|
rps float64
|
|
burst float64
|
|
ttl time.Duration
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// ClientIP returns the real client IP for a request, optionally trusting a
// single reverse-proxy address.
//
// Security (DEF-03): X-Forwarded-For and X-Real-IP can be forged by any
// client. They are consulted only when the immediate TCP peer (r.RemoteAddr)
// equals trustedProxy. When trustedProxy is nil or the peer does not match,
// the host part of r.RemoteAddr is returned unconditionally. An attacker who
// sends forged headers from their own connection is therefore still keyed by
// their real address.
//
// When the request does come from the trusted proxy:
//   - X-Real-IP is preferred when it parses as an IP. This assumes the proxy
//     overwrites any client-supplied X-Real-IP (e.g. nginx
//     `proxy_set_header X-Real-IP $remote_addr`). NOTE(review): confirm the
//     deployment's proxy configuration does so.
//   - Otherwise the RIGHTMOST X-Forwarded-For entry is used. A proxy that
//     appends to X-Forwarded-For leaves whatever the client sent on the left,
//     so the leftmost entry is attacker-controlled. The rightmost entry is the
//     one the trusted proxy itself added. If the proxy overwrites the header
//     instead, there is a single entry and both ends agree. Multiple
//     X-Forwarded-For header lines are treated as one list in arrival order.
//   - If neither header yields a valid IP, the proxy's own address is
//     returned as a fallback.
//
// Entries that are not valid IPs are never skipped over in favour of earlier
// ones, since earlier entries are client-controlled.
func ClientIP(r *http.Request, trustedProxy net.IP) string {
	remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port; use it verbatim.
		remoteHost = r.RemoteAddr
	}

	if trustedProxy == nil {
		return remoteHost
	}
	if peer := net.ParseIP(remoteHost); peer == nil || !peer.Equal(trustedProxy) {
		return remoteHost
	}

	// Request is from the trusted proxy; extract the real client IP.
	// An absent X-Real-IP yields "", which fails to parse and is skipped.
	if ip := net.ParseIP(strings.TrimSpace(r.Header.Get("X-Real-IP"))); ip != nil {
		return ip.String()
	}

	if lines := r.Header.Values("X-Forwarded-For"); len(lines) > 0 {
		last := lines[len(lines)-1]
		if i := strings.LastIndexByte(last, ','); i >= 0 {
			last = last[i+1:]
		}
		if ip := net.ParseIP(strings.TrimSpace(last)); ip != nil {
			return ip.String()
		}
	}

	return remoteHost
}
|
|
|
|
// RateLimit returns middleware that enforces a per-client-IP token bucket.
//
//   - rps is the sustained rate: tokens refilled per second.
//   - burst is the bucket capacity, i.e. both the initial and the maximum
//     token count.
//   - trustedProxy, when non-nil, makes the client IP come from ClientIP's
//     proxy-aware header logic. Pass nil when not running behind a reverse
//     proxy.
//
// A request that finds its bucket empty gets 429 with "Retry-After: 60".
//
// Security (DEF-03): forwarded headers are honoured only for requests whose
// RemoteAddr is the trusted proxy, so a client cannot choose its own bucket by
// spoofing them.
//
// Each call starts a background eviction goroutine that never exits, so build
// the middleware once at setup time rather than per request.
func RateLimit(rps float64, burst int, trustedProxy net.IP) func(http.Handler) http.Handler {
	rl := &ipRateLimiter{
		ips:   make(map[string]*rateLimitEntry),
		rps:   rps,
		burst: float64(burst),
		ttl:   10 * time.Minute,
	}

	// Evict idle buckets periodically so the map cannot grow without bound.
	go rl.cleanup()

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if rl.allow(ClientIP(r, trustedProxy)) {
				next.ServeHTTP(w, r)
				return
			}
			w.Header().Set("Retry-After", "60")
			writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited")
		})
	}
}
|
|
|
|
// allow reports whether a request from ip may proceed, consuming one token if
// so. A previously unseen ip starts with a full bucket (l.burst tokens).
// Tokens accrue at l.rps per second of elapsed time, capped at l.burst.
func (l *ipRateLimiter) allow(ip string) bool {
	// Look up or create the bucket under the map lock, then drop that lock so
	// per-IP accounting does not serialize unrelated clients.
	l.mu.Lock()
	bucket, found := l.ips[ip]
	if !found {
		bucket = &rateLimitEntry{lastSeen: time.Now(), tokens: l.burst}
		l.ips[ip] = bucket
	}
	l.mu.Unlock()

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	now := time.Now()
	refill := now.Sub(bucket.lastSeen).Seconds() * l.rps
	bucket.tokens = minFloat64(l.burst, bucket.tokens+refill)
	bucket.lastSeen = now

	if bucket.tokens < 1 {
		return false
	}
	bucket.tokens--
	return true
}
|
|
|
|
// cleanup evicts buckets that have been idle for longer than l.ttl, checking
// every five minutes. It loops forever and is meant to run in its own
// goroutine.
//
// Lock order is l.mu then the entry's mu. allow never holds both at once, so
// the two cannot deadlock.
func (l *ipRateLimiter) cleanup() {
	tick := time.NewTicker(5 * time.Minute)
	defer tick.Stop()

	for range tick.C {
		l.mu.Lock()
		threshold := time.Now().Add(-l.ttl)
		for key, bucket := range l.ips {
			bucket.mu.Lock()
			stale := bucket.lastSeen.Before(threshold)
			bucket.mu.Unlock()
			if stale {
				// Deleting during range is permitted by the Go spec.
				delete(l.ips, key)
			}
		}
		l.mu.Unlock()
	}
}
|
|
|
|
// extractBearerToken returns the credential from an "Authorization: Bearer
// <token>" header.
//
// The scheme name is matched case-insensitively. RFC 7235 allows one or more
// spaces between the scheme and the credentials, so surrounding whitespace is
// trimmed from the token. It returns an error when the header is missing, uses
// another scheme (or has no space after the scheme), or carries no token after
// trimming.
func extractBearerToken(r *http.Request) (string, error) {
	auth := r.Header.Get("Authorization")
	if auth == "" {
		return "", fmt.Errorf("missing Authorization header")
	}
	scheme, credentials, found := strings.Cut(auth, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", fmt.Errorf("malformed Authorization header")
	}
	// Without trimming, "Bearer  <jwt>" would hand " <jwt>" to the validator,
	// and a whitespace-only token would slip past the emptiness check.
	tok := strings.TrimSpace(credentials)
	if tok == "" {
		return "", fmt.Errorf("empty Bearer token")
	}
	return tok, nil
}
|
|
|
|
// apiError is the JSON body of every error response produced by this package,
// for example {"error":"invalid token","code":"unauthorized"}.
type apiError struct {
	Error string `json:"error"` // human-readable message
	Code  string `json:"code"`  // machine-readable error code
}
|
|
|
|
// writeError sends status with a JSON body {"error": message, "code": code}
// and Content-Type application/json.
func writeError(w http.ResponseWriter, status int, message, code string) {
	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	payload := apiError{Error: message, Code: code}
	// An encode/write failure is deliberately ignored: the status line has
	// already been sent, and there is no better channel to report it on.
	_ = json.NewEncoder(w).Encode(payload)
}
|
|
|
|
// WriteError sends status with a JSON body {"error": message, "code": code}.
// It is the exported entry point for handler packages, so their error
// responses have exactly the same shape as this middleware's.
func WriteError(w http.ResponseWriter, status int, message, code string) {
	writeError(w, status, message, code)
}
|
|
|
|
// minFloat64 returns a when a < b, and b otherwise. With a NaN operand the
// comparison is false, so b is returned.
func minFloat64(a, b float64) float64 {
	switch {
	case a < b:
		return a
	default:
		return b
	}
}
|
|
|
|
// ResourceBuilder constructs the policy.Resource that a request targets.
// RequirePolicy calls it with the request and the already-validated claims.
// A typical implementation reads a path parameter (such as an account UUID)
// and fills in the target's owner UUID, service name and tags from storage.
//
// Whatever Type the builder sets is overwritten with the resource type passed
// to RequirePolicy. A nil ResourceBuilder behaves like one returning the zero
// Resource: no owner, no service name, no tags.
type ResourceBuilder func(r *http.Request, claims *token.Claims) policy.Resource
|
|
|
|
// AccountTypeLookup resolves the account type ("human" or "system") for the
// given account UUID. When non-nil, RequirePolicy calls it with the
// token subject to populate PolicyInput.AccountType. That value is what rules
// with an AccountTypes match condition compare against.
//
// Callers supply the implementation, presumably backed by
// db.GetAccountByUUID. This package already imports db (RequireAuth takes a
// *db.DB), so the function indirection is not about avoiding an import cycle.
// It keeps the account lookup out of RequirePolicy itself.
//
// Returning an empty string is safe: it will not match any AccountTypes
// condition on rules.
type AccountTypeLookup func(subjectUUID string) string
|
|
|
|
// PolicyDenyLogger records a policy denial. When non-nil, RequirePolicy invokes
// it just before responding 403. It receives:
//   - the request;
//   - the caller's claims;
//   - the action and resource that were evaluated;
//   - the ID of the matched rule, or 0 when no rule matched.
//
// Callers are expected to write an audit event (e.g. via db.WriteAuditEvent).
// Taking this as a function keeps the audit write itself out of the
// middleware.
type PolicyDenyLogger func(r *http.Request, claims *token.Claims, action policy.Action, res policy.Resource, matchedRuleID int64)
|
|
|
|
// RequirePolicy returns middleware that asks the policy engine whether the
// authenticated caller may perform action on a resource of type resType. It
// must be chained after RequireAuth.
//
// Per request it:
//  1. Fails closed with 403 when no claims are in the context.
//  2. Builds the resource with buildResource (if non-nil) and then forces its
//     Type to resType.
//  3. Resolves the caller's account type with lookupAccountType (if non-nil).
//  4. Evaluates the engine. On Deny it reports the denial via logDeny (if
//     non-nil) with the matched rule ID (0 if none) and responds 403.
//     Otherwise it calls the next handler.
//
// Security: the engine's deny-wins and default-deny semantics mean a
// misconfiguration (missing rule, engine error) yields 403, never a silent
// permit. The matched rule ID is passed on for audit traceability.
//
// AccountType is not carried in the JWT, to avoid a signature-breaking change
// to IssueToken. It is resolved lazily through lookupAccountType; an empty
// result matches no AccountTypes rule condition.
//
// During the migration period RequirePolicy is meant to run alongside
// RequireRole("admin"), and both checks must pass. RequireRole can be dropped
// once policy coverage is validated.
func RequirePolicy(
	eng *policy.Engine,
	action policy.Action,
	resType policy.ResourceType,
	buildResource ResourceBuilder,
	lookupAccountType AccountTypeLookup,
	logDeny PolicyDenyLogger,
) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			claims := ClaimsFromContext(r.Context())
			if claims == nil {
				// RequireAuth was not applied upstream; fail closed.
				writeError(w, http.StatusForbidden, "forbidden", "forbidden")
				return
			}

			var res policy.Resource
			if buildResource != nil {
				res = buildResource(r, claims)
			}
			// The configured type always wins over whatever the builder set.
			res.Type = resType

			var accountType string
			if lookupAccountType != nil {
				accountType = lookupAccountType(claims.Subject)
			}

			effect, matched := eng.Evaluate(policy.PolicyInput{
				Subject:     claims.Subject,
				AccountType: accountType,
				Roles:       claims.Roles,
				Action:      action,
				Resource:    res,
			})
			if effect != policy.Deny {
				next.ServeHTTP(w, r)
				return
			}

			var ruleID int64
			if matched != nil {
				ruleID = matched.ID
			}
			if logDeny != nil {
				logDeny(r, claims, action, res, ruleID)
			}
			writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden")
		})
	}
}
|