// Package middleware provides HTTP middleware for the MCIAS server. // // Security design: // - RequireAuth extracts the Bearer token from the Authorization header, // validates it (alg check, signature, expiry, issuer), and checks revocation // against the database before injecting claims into the request context. // - RequireRole checks claims from context for the required role. // No role implies no access; the check fails closed. // - RateLimit implements a per-IP token bucket to limit login brute-force. // - RequestLogger logs request metadata but never logs the Authorization // header value (which contains credential tokens). package middleware import ( "context" "crypto/ed25519" "encoding/json" "errors" "fmt" "log/slog" "net" "net/http" "strings" "sync" "time" "git.wntrmute.dev/kyle/mcias/internal/db" "git.wntrmute.dev/kyle/mcias/internal/policy" "git.wntrmute.dev/kyle/mcias/internal/token" ) // contextKey is the unexported type for context keys in this package, preventing // collisions with keys from other packages. type contextKey int const ( claimsKey contextKey = iota ) // ClaimsFromContext retrieves the validated JWT claims from the request context. // Returns nil if no claims are present (unauthenticated request). // // Security: The type assertion uses the ok form so a context value of the wrong // type (e.g. from a different package's context injection) returns nil rather // than panicking. func ClaimsFromContext(ctx context.Context) *token.Claims { // ok is intentionally checked: if the value is absent or the wrong type, // c is nil (zero value for *token.Claims), which is the correct "no auth" result. c, ok := ctx.Value(claimsKey).(*token.Claims) if !ok { return nil } return c } // RequestLogger returns middleware that logs each request at INFO level. // The Authorization header is intentionally never logged. func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() // Wrap the ResponseWriter to capture the status code. rw := &responseWriter{ResponseWriter: w, status: http.StatusOK} next.ServeHTTP(rw, r) logger.Info("request", "method", r.Method, "path", r.URL.Path, "status", rw.status, "duration_ms", time.Since(start).Milliseconds(), "remote_addr", r.RemoteAddr, "user_agent", r.UserAgent(), // Security: Authorization header is never logged. ) }) } } // responseWriter wraps http.ResponseWriter to capture the status code. type responseWriter struct { http.ResponseWriter status int } func (rw *responseWriter) WriteHeader(code int) { rw.status = code rw.ResponseWriter.WriteHeader(code) } // RequireAuth returns middleware that validates a Bearer JWT and injects the // claims into the request context. Returns 401 on any auth failure. // // Security: Token validation order: // 1. Extract Bearer token from Authorization header. // 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer). // 3. Check the JTI against the revocation table in the database. // 4. Inject validated claims into context for downstream handlers. func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tokenStr, err := extractBearerToken(r) if err != nil { writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized") return } claims, err := token.ValidateToken(pubKey, tokenStr, issuer) if err != nil { // Security: Map all token errors to a generic 401; do not // reveal which specific check failed. if errors.Is(err, token.ErrExpiredToken) { writeError(w, http.StatusUnauthorized, "token expired", "token_expired") return } writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized") return } // Security: Check revocation table. A token may be cryptographically // valid but explicitly revoked (logout, account suspension, etc.). rec, err := database.GetTokenRecord(claims.JTI) if err != nil { if errors.Is(err, db.ErrNotFound) { // Token not tracked — could be from a different server instance // or pre-dates tracking. Reject to be safe (fail closed). writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized") return } writeError(w, http.StatusInternalServerError, "internal error", "internal_error") return } if rec.IsRevoked() { writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked") return } ctx := context.WithValue(r.Context(), claimsKey, claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // RequireRole returns middleware that checks whether the authenticated user has // the given role. Must be used after RequireAuth. Returns 403 if role is absent. func RequireRole(role string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := ClaimsFromContext(r.Context()) if claims == nil { // RequireAuth was not applied upstream; fail closed. writeError(w, http.StatusForbidden, "forbidden", "forbidden") return } if !claims.HasRole(role) { writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden") return } next.ServeHTTP(w, r) }) } } // rateLimitEntry holds the token bucket state for a single IP. type rateLimitEntry struct { lastSeen time.Time tokens float64 mu sync.Mutex } // ipRateLimiter implements a per-IP token bucket rate limiter. type ipRateLimiter struct { ips map[string]*rateLimitEntry rps float64 burst float64 ttl time.Duration mu sync.Mutex } // RateLimit returns middleware implementing a per-IP token bucket. // rps is the sustained request rate (tokens refilled per second). // burst is the maximum burst size (initial and maximum token count). // // Security: Rate limiting is applied at the IP level. In production, the // server should be behind a reverse proxy that sets X-Forwarded-For; this // middleware uses RemoteAddr directly which may be the proxy IP. For single- // instance deployment without a proxy, RemoteAddr is the client IP. func RateLimit(rps float64, burst int) func(http.Handler) http.Handler { limiter := &ipRateLimiter{ rps: rps, burst: float64(burst), ttl: 10 * time.Minute, ips: make(map[string]*rateLimitEntry), } // Background cleanup of idle entries to prevent unbounded memory growth. go limiter.cleanup() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { ip = r.RemoteAddr } if !limiter.allow(ip) { w.Header().Set("Retry-After", "60") writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited") return } next.ServeHTTP(w, r) }) } } // allow returns true if a request from ip is permitted under the rate limit. func (l *ipRateLimiter) allow(ip string) bool { l.mu.Lock() entry, ok := l.ips[ip] if !ok { entry = &rateLimitEntry{tokens: l.burst, lastSeen: time.Now()} l.ips[ip] = entry } l.mu.Unlock() entry.mu.Lock() defer entry.mu.Unlock() now := time.Now() elapsed := now.Sub(entry.lastSeen).Seconds() entry.tokens = minFloat64(l.burst, entry.tokens+elapsed*l.rps) entry.lastSeen = now if entry.tokens < 1 { return false } entry.tokens-- return true } // cleanup periodically removes idle rate-limit entries. func (l *ipRateLimiter) cleanup() { ticker := time.NewTicker(5 * time.Minute) defer ticker.Stop() for range ticker.C { l.mu.Lock() cutoff := time.Now().Add(-l.ttl) for ip, entry := range l.ips { entry.mu.Lock() if entry.lastSeen.Before(cutoff) { delete(l.ips, ip) } entry.mu.Unlock() } l.mu.Unlock() } } // extractBearerToken extracts the token from "Authorization: Bearer ". func extractBearerToken(r *http.Request) (string, error) { auth := r.Header.Get("Authorization") if auth == "" { return "", fmt.Errorf("missing Authorization header") } parts := strings.SplitN(auth, " ", 2) if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { return "", fmt.Errorf("malformed Authorization header") } if parts[1] == "" { return "", fmt.Errorf("empty Bearer token") } return parts[1], nil } // apiError is the uniform error response structure. type apiError struct { Error string `json:"error"` Code string `json:"code"` } // writeError writes a JSON error response. func writeError(w http.ResponseWriter, status int, message, code string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) // Intentionally ignoring the error here; if the write fails, the client // already got the status code. _ = json.NewEncoder(w).Encode(apiError{Error: message, Code: code}) } // WriteError is the exported version for use by handler packages. func WriteError(w http.ResponseWriter, status int, message, code string) { writeError(w, status, message, code) } // minFloat64 returns the smaller of two float64 values. func minFloat64(a, b float64) float64 { if a < b { return a } return b } // ResourceBuilder is a function that assembles the policy.Resource for a // specific request. The middleware calls it after claims are extracted. // Implementations typically read the path parameter (e.g. account UUID) and // look up the target account's owner UUID, service name, and tags from the DB. // // A nil ResourceBuilder is equivalent to a function that returns an empty // Resource (no owner, no service name, no tags). type ResourceBuilder func(r *http.Request, claims *token.Claims) policy.Resource // AccountTypeLookup resolves the account type ("human" or "system") for the // given account UUID. The middleware calls this to populate PolicyInput when // the AccountTypes match condition is used in any rule. // // Callers supply an implementation backed by db.GetAccountByUUID; the // middleware does not import the db package directly to avoid a cycle. // Returning an empty string is safe — it simply will not match any // AccountTypes condition on rules. type AccountTypeLookup func(subjectUUID string) string // PolicyDenyLogger is a function that records a policy denial in the audit log. // Callers supply an implementation that calls db.WriteAuditEvent; the middleware // itself does not import the db package directly for the audit write, keeping // the dependency on policy and db separate. type PolicyDenyLogger func(r *http.Request, claims *token.Claims, action policy.Action, res policy.Resource, matchedRuleID int64) // RequirePolicy returns middleware that evaluates the policy engine for the // given action and resource type. Must be used after RequireAuth. // // Security: deny-wins and default-deny semantics mean that any misconfiguration // (missing rule, engine error) results in a 403, never silent permit. The // matched rule ID is included in the audit event for traceability. // // AccountType is not stored in the JWT to avoid a signature-breaking change to // IssueToken. It is resolved lazily via lookupAccountType (a DB-backed closure // provided by the caller). Returning "" from lookupAccountType is safe: no // AccountTypes rule condition will match an empty string. // // RequirePolicy is intended to coexist with RequireRole("admin") during the // migration period. Once full policy coverage is validated, RequireRole can be // removed. During the transition both checks must pass. func RequirePolicy( eng *policy.Engine, action policy.Action, resType policy.ResourceType, buildResource ResourceBuilder, lookupAccountType AccountTypeLookup, logDeny PolicyDenyLogger, ) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { claims := ClaimsFromContext(r.Context()) if claims == nil { // RequireAuth was not applied upstream; fail closed. writeError(w, http.StatusForbidden, "forbidden", "forbidden") return } var res policy.Resource res.Type = resType if buildResource != nil { res = buildResource(r, claims) res.Type = resType // ensure type is always set even if builder overrides } accountType := "" if lookupAccountType != nil { accountType = lookupAccountType(claims.Subject) } input := policy.PolicyInput{ Subject: claims.Subject, AccountType: accountType, Roles: claims.Roles, Action: action, Resource: res, } effect, matched := eng.Evaluate(input) if effect == policy.Deny { var ruleID int64 if matched != nil { ruleID = matched.ID } if logDeny != nil { logDeny(r, claims, action, res, ruleID) } writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden") return } next.ServeHTTP(w, r) }) } }