// Package middleware provides HTTP middleware for the MCIAS server. // // Security design: // - RequireAuth extracts the Bearer token from the Authorization header, // validates it (alg check, signature, expiry, issuer), and checks revocation // against the database before injecting claims into the request context. // - RequireRole checks claims from context for the required role. // No role implies no access; the check fails closed. // - RateLimit implements a per-IP token bucket to limit login brute-force. // - RequestLogger logs request metadata but never logs the Authorization // header value (which contains credential tokens). package middleware import ( "context" "crypto/ed25519" "encoding/json" "errors" "fmt" "log/slog" "net" "net/http" "strings" "sync" "time" "git.wntrmute.dev/kyle/mcias/internal/db" "git.wntrmute.dev/kyle/mcias/internal/token" ) // contextKey is the unexported type for context keys in this package, preventing // collisions with keys from other packages. type contextKey int const ( claimsKey contextKey = iota ) // ClaimsFromContext retrieves the validated JWT claims from the request context. // Returns nil if no claims are present (unauthenticated request). func ClaimsFromContext(ctx context.Context) *token.Claims { c, _ := ctx.Value(claimsKey).(*token.Claims) return c } // RequestLogger returns middleware that logs each request at INFO level. // The Authorization header is intentionally never logged. func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() // Wrap the ResponseWriter to capture the status code. 
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(rw, r) logger.Info("request", "method", r.Method, "path", r.URL.Path, "status", rw.status, "duration_ms", time.Since(start).Milliseconds(), "remote_addr", r.RemoteAddr, "user_agent", r.UserAgent(), // Security: Authorization header is never logged. ) }) } } // responseWriter wraps http.ResponseWriter to capture the status code. type responseWriter struct { http.ResponseWriter status int } func (rw *responseWriter) WriteHeader(code int) { rw.status = code rw.ResponseWriter.WriteHeader(code) } // RequireAuth returns middleware that validates a Bearer JWT and injects the // claims into the request context. Returns 401 on any auth failure. // // Security: Token validation order: // 1. Extract Bearer token from Authorization header. // 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer). // 3. Check the JTI against the revocation table in the database. // 4. Inject validated claims into context for downstream handlers. func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenStr, err := extractBearerToken(r) if err != nil { writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized") return } claims, err := token.ValidateToken(pubKey, tokenStr, issuer) if err != nil { // Security: Map all token errors to a generic 401; do not // reveal which specific check failed. if errors.Is(err, token.ErrExpiredToken) { writeError(w, http.StatusUnauthorized, "token expired", "token_expired") return } writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized") return } // Security: Check revocation table. A token may be cryptographically // valid but explicitly revoked (logout, account suspension, etc.). 
rec, err := database.GetTokenRecord(claims.JTI) if err != nil { if errors.Is(err, db.ErrNotFound) { // Token not tracked — could be from a different server instance // or pre-dates tracking. Reject to be safe (fail closed). writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized") return } writeError(w, http.StatusInternalServerError, "internal error", "internal_error") return } if rec.IsRevoked() { writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked") return } ctx := context.WithValue(r.Context(), claimsKey, claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // RequireRole returns middleware that checks whether the authenticated user has // the given role. Must be used after RequireAuth. Returns 403 if role is absent. func RequireRole(role string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := ClaimsFromContext(r.Context()) if claims == nil { // RequireAuth was not applied upstream; fail closed. writeError(w, http.StatusForbidden, "forbidden", "forbidden") return } if !claims.HasRole(role) { writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden") return } next.ServeHTTP(w, r) }) } } // rateLimitEntry holds the token bucket state for a single IP. type rateLimitEntry struct { tokens float64 lastSeen time.Time mu sync.Mutex } // ipRateLimiter implements a per-IP token bucket rate limiter. type ipRateLimiter struct { rps float64 // refill rate: tokens per second burst float64 // bucket capacity ttl time.Duration // how long to keep idle entries mu sync.Mutex ips map[string]*rateLimitEntry } // RateLimit returns middleware implementing a per-IP token bucket. // rps is the sustained request rate (tokens refilled per second). // burst is the maximum burst size (initial and maximum token count). // // Security: Rate limiting is applied at the IP level. 
In production, the // server should be behind a reverse proxy that sets X-Forwarded-For; this // middleware uses RemoteAddr directly which may be the proxy IP. For single- // instance deployment without a proxy, RemoteAddr is the client IP. func RateLimit(rps float64, burst int) func(http.Handler) http.Handler { limiter := &ipRateLimiter{ rps: rps, burst: float64(burst), ttl: 10 * time.Minute, ips: make(map[string]*rateLimitEntry), } // Background cleanup of idle entries to prevent unbounded memory growth. go limiter.cleanup() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { ip = r.RemoteAddr } if !limiter.allow(ip) { w.Header().Set("Retry-After", "60") writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited") return } next.ServeHTTP(w, r) }) } } // allow returns true if a request from ip is permitted under the rate limit. func (l *ipRateLimiter) allow(ip string) bool { l.mu.Lock() entry, ok := l.ips[ip] if !ok { entry = &rateLimitEntry{tokens: l.burst, lastSeen: time.Now()} l.ips[ip] = entry } l.mu.Unlock() entry.mu.Lock() defer entry.mu.Unlock() now := time.Now() elapsed := now.Sub(entry.lastSeen).Seconds() entry.tokens = minFloat64(l.burst, entry.tokens+elapsed*l.rps) entry.lastSeen = now if entry.tokens < 1 { return false } entry.tokens-- return true } // cleanup periodically removes idle rate-limit entries. func (l *ipRateLimiter) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { l.mu.Lock() cutoff := time.Now().Add(-l.ttl) for ip, entry := range l.ips { entry.mu.Lock() if entry.lastSeen.Before(cutoff) { delete(l.ips, ip) } entry.mu.Unlock() } l.mu.Unlock() } } // extractBearerToken extracts the token from "Authorization: Bearer ". 
// extractBearerToken extracts the token from "Authorization: Bearer <token>".
//
// The scheme is matched case-insensitively. Whitespace between the scheme and
// the token, and around the token, is ignored; RFC 7235 permits one or more
// spaces after the auth-scheme. An error is returned if the header is absent,
// names a different scheme, or carries no token.
func extractBearerToken(r *http.Request) (string, error) {
	auth := r.Header.Get("Authorization")
	if auth == "" {
		return "", fmt.Errorf("missing Authorization header")
	}
	scheme, credentials, found := strings.Cut(auth, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return "", fmt.Errorf("malformed Authorization header")
	}
	tok := strings.TrimSpace(credentials)
	if tok == "" {
		return "", fmt.Errorf("empty Bearer token")
	}
	return tok, nil
}

// apiError is the uniform error response structure.
type apiError struct {
	Error string `json:"error"`
	Code  string `json:"code"`
}

// writeError writes a JSON error response with the given HTTP status, a
// human-readable message, and a machine-readable code.
func writeError(w http.ResponseWriter, status int, message, code string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// Intentionally ignoring the error here; if the write fails, the client
	// already got the status code.
	_ = json.NewEncoder(w).Encode(apiError{Error: message, Code: code})
}

// WriteError is the exported version for use by handler packages.
func WriteError(w http.ResponseWriter, status int, message, code string) {
	writeError(w, status, message, code)
}

// minFloat64 returns the smaller of two float64 values.
func minFloat64(a, b float64) float64 {
	if a < b {
		return a
	}
	return b
}