checkpoint mciassrv
This commit is contained in:
290
internal/middleware/middleware.go
Normal file
290
internal/middleware/middleware.go
Normal file
@@ -0,0 +1,290 @@
|
||||
// Package middleware provides HTTP middleware for the MCIAS server.
|
||||
//
|
||||
// Security design:
|
||||
// - RequireAuth extracts the Bearer token from the Authorization header,
|
||||
// validates it (alg check, signature, expiry, issuer), and checks revocation
|
||||
// against the database before injecting claims into the request context.
|
||||
// - RequireRole checks claims from context for the required role.
|
||||
// No role implies no access; the check fails closed.
|
||||
// - RateLimit implements a per-IP token bucket to limit login brute-force.
|
||||
// - RequestLogger logs request metadata but never logs the Authorization
|
||||
// header value (which contains credential tokens).
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||
)
|
||||
|
||||
// contextKey is the unexported type for context keys in this package, preventing
|
||||
// collisions with keys from other packages.
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
claimsKey contextKey = iota
|
||||
)
|
||||
|
||||
// ClaimsFromContext retrieves the validated JWT claims from the request context.
|
||||
// Returns nil if no claims are present (unauthenticated request).
|
||||
func ClaimsFromContext(ctx context.Context) *token.Claims {
|
||||
c, _ := ctx.Value(claimsKey).(*token.Claims)
|
||||
return c
|
||||
}
|
||||
|
||||
// RequestLogger returns middleware that logs each request at INFO level.
|
||||
// The Authorization header is intentionally never logged.
|
||||
func RequestLogger(logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
// Wrap the ResponseWriter to capture the status code.
|
||||
rw := &responseWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(rw, r)
|
||||
|
||||
logger.Info("request",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", rw.status,
|
||||
"duration_ms", time.Since(start).Milliseconds(),
|
||||
"remote_addr", r.RemoteAddr,
|
||||
"user_agent", r.UserAgent(),
|
||||
// Security: Authorization header is never logged.
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture the status code.
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) {
|
||||
rw.status = code
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// RequireAuth returns middleware that validates a Bearer JWT and injects the
|
||||
// claims into the request context. Returns 401 on any auth failure.
|
||||
//
|
||||
// Security: Token validation order:
|
||||
// 1. Extract Bearer token from Authorization header.
|
||||
// 2. Validate the JWT (alg=EdDSA, signature, expiry, issuer).
|
||||
// 3. Check the JTI against the revocation table in the database.
|
||||
// 4. Inject validated claims into context for downstream handlers.
|
||||
func RequireAuth(pubKey ed25519.PublicKey, database *db.DB, issuer string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tokenStr, err := extractBearerToken(r)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "missing or malformed Authorization header", "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := token.ValidateToken(pubKey, tokenStr, issuer)
|
||||
if err != nil {
|
||||
// Security: Map all token errors to a generic 401; do not
|
||||
// reveal which specific check failed.
|
||||
if errors.Is(err, token.ErrExpiredToken) {
|
||||
writeError(w, http.StatusUnauthorized, "token expired", "token_expired")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusUnauthorized, "invalid token", "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// Security: Check revocation table. A token may be cryptographically
|
||||
// valid but explicitly revoked (logout, account suspension, etc.).
|
||||
rec, err := database.GetTokenRecord(claims.JTI)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
// Token not tracked — could be from a different server instance
|
||||
// or pre-dates tracking. Reject to be safe (fail closed).
|
||||
writeError(w, http.StatusUnauthorized, "unrecognized token", "unauthorized")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||
return
|
||||
}
|
||||
if rec.IsRevoked() {
|
||||
writeError(w, http.StatusUnauthorized, "token has been revoked", "token_revoked")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), claimsKey, claims)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole returns middleware that checks whether the authenticated user has
|
||||
// the given role. Must be used after RequireAuth. Returns 403 if role is absent.
|
||||
func RequireRole(role string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := ClaimsFromContext(r.Context())
|
||||
if claims == nil {
|
||||
// RequireAuth was not applied upstream; fail closed.
|
||||
writeError(w, http.StatusForbidden, "forbidden", "forbidden")
|
||||
return
|
||||
}
|
||||
if !claims.HasRole(role) {
|
||||
writeError(w, http.StatusForbidden, "insufficient privileges", "forbidden")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// rateLimitEntry holds the token bucket state for a single IP.
|
||||
type rateLimitEntry struct {
|
||||
tokens float64
|
||||
lastSeen time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ipRateLimiter implements a per-IP token bucket rate limiter.
|
||||
type ipRateLimiter struct {
|
||||
rps float64 // refill rate: tokens per second
|
||||
burst float64 // bucket capacity
|
||||
ttl time.Duration // how long to keep idle entries
|
||||
mu sync.Mutex
|
||||
ips map[string]*rateLimitEntry
|
||||
}
|
||||
|
||||
// RateLimit returns middleware implementing a per-IP token bucket.
|
||||
// rps is the sustained request rate (tokens refilled per second).
|
||||
// burst is the maximum burst size (initial and maximum token count).
|
||||
//
|
||||
// Security: Rate limiting is applied at the IP level. In production, the
|
||||
// server should be behind a reverse proxy that sets X-Forwarded-For; this
|
||||
// middleware uses RemoteAddr directly which may be the proxy IP. For single-
|
||||
// instance deployment without a proxy, RemoteAddr is the client IP.
|
||||
func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
|
||||
limiter := &ipRateLimiter{
|
||||
rps: rps,
|
||||
burst: float64(burst),
|
||||
ttl: 10 * time.Minute,
|
||||
ips: make(map[string]*rateLimitEntry),
|
||||
}
|
||||
|
||||
// Background cleanup of idle entries to prevent unbounded memory growth.
|
||||
go limiter.cleanup()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
|
||||
if !limiter.allow(ip) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
writeError(w, http.StatusTooManyRequests, "rate limit exceeded", "rate_limited")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// allow returns true if a request from ip is permitted under the rate limit.
|
||||
func (l *ipRateLimiter) allow(ip string) bool {
|
||||
l.mu.Lock()
|
||||
entry, ok := l.ips[ip]
|
||||
if !ok {
|
||||
entry = &rateLimitEntry{tokens: l.burst, lastSeen: time.Now()}
|
||||
l.ips[ip] = entry
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
entry.mu.Lock()
|
||||
defer entry.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(entry.lastSeen).Seconds()
|
||||
entry.tokens = min(l.burst, entry.tokens+elapsed*l.rps)
|
||||
entry.lastSeen = now
|
||||
|
||||
if entry.tokens < 1 {
|
||||
return false
|
||||
}
|
||||
entry.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
// cleanup periodically removes idle rate-limit entries.
|
||||
func (l *ipRateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
l.mu.Lock()
|
||||
cutoff := time.Now().Add(-l.ttl)
|
||||
for ip, entry := range l.ips {
|
||||
entry.mu.Lock()
|
||||
if entry.lastSeen.Before(cutoff) {
|
||||
delete(l.ips, ip)
|
||||
}
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// extractBearerToken extracts the token from "Authorization: Bearer <token>".
|
||||
func extractBearerToken(r *http.Request) (string, error) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
return "", fmt.Errorf("missing Authorization header")
|
||||
}
|
||||
parts := strings.SplitN(auth, " ", 2)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return "", fmt.Errorf("malformed Authorization header")
|
||||
}
|
||||
if parts[1] == "" {
|
||||
return "", fmt.Errorf("empty Bearer token")
|
||||
}
|
||||
return parts[1], nil
|
||||
}
|
||||
|
||||
// apiError is the uniform error response structure.
|
||||
type apiError struct {
|
||||
Error string `json:"error"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
|
||||
// writeError writes a JSON error response.
|
||||
func writeError(w http.ResponseWriter, status int, message, code string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
// Intentionally ignoring the error here; if the write fails, the client
|
||||
// already got the status code.
|
||||
_ = json.NewEncoder(w).Encode(apiError{Error: message, Code: code})
|
||||
}
|
||||
|
||||
// WriteError is the exported version for use by handler packages.
|
||||
func WriteError(w http.ResponseWriter, status int, message, code string) {
|
||||
writeError(w, status, message, code)
|
||||
}
|
||||
|
||||
// min returns the smaller of two float64 values.
|
||||
func min(a, b float64) float64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
342
internal/middleware/middleware_test.go
Normal file
342
internal/middleware/middleware_test.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/token"
|
||||
)
|
||||
|
||||
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
|
||||
t.Helper()
|
||||
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("generate test key: %v", err)
|
||||
}
|
||||
return pub, priv
|
||||
}
|
||||
|
||||
func openTestDB(t *testing.T) *db.DB {
|
||||
t.Helper()
|
||||
database, err := db.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("open test db: %v", err)
|
||||
}
|
||||
if err := db.Migrate(database); err != nil {
|
||||
t.Fatalf("migrate test db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = database.Close() })
|
||||
return database
|
||||
}
|
||||
|
||||
const testIssuer = "https://auth.example.com"
|
||||
|
||||
// issueAndTrackToken creates a valid JWT and records it in the DB.
|
||||
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
|
||||
t.Helper()
|
||||
tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken: %v", err)
|
||||
}
|
||||
if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
||||
t.Fatalf("TrackToken: %v", err)
|
||||
}
|
||||
return tokenStr
|
||||
}
|
||||
|
||||
func TestRequestLogger(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
||||
|
||||
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rr.Code)
|
||||
}
|
||||
logOutput := buf.String()
|
||||
if logOutput == "" {
|
||||
t.Error("expected log output, got empty string")
|
||||
}
|
||||
// Security: Authorization header must not appear in logs.
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
|
||||
req2.Header.Set("Authorization", "Bearer secret-token-value")
|
||||
buf.Reset()
|
||||
rr2 := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr2, req2)
|
||||
if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
|
||||
t.Error("log output contains Authorization token value — credential leak!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthValid(t *testing.T) {
|
||||
pub, priv := generateTestKey(t)
|
||||
database := openTestDB(t)
|
||||
|
||||
acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"})
|
||||
|
||||
reached := false
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reached = true
|
||||
claims := ClaimsFromContext(r.Context())
|
||||
if claims == nil {
|
||||
t.Error("claims not in context")
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
if !reached {
|
||||
t.Error("handler was not reached with valid token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthMissingHeader(t *testing.T) {
|
||||
pub, priv := generateTestKey(t)
|
||||
_ = priv
|
||||
database := openTestDB(t)
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be reached without auth")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthInvalidToken(t *testing.T) {
|
||||
pub, _ := generateTestKey(t)
|
||||
database := openTestDB(t)
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be reached with invalid token")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer not.a.valid.jwt")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthRevokedToken(t *testing.T) {
|
||||
pub, priv := generateTestKey(t)
|
||||
database := openTestDB(t)
|
||||
|
||||
acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil)
|
||||
|
||||
// Extract JTI and revoke the token.
|
||||
claims, err := token.ValidateToken(pub, tokenStr, testIssuer)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken: %v", err)
|
||||
}
|
||||
if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil {
|
||||
t.Fatalf("RevokeToken: %v", err)
|
||||
}
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be reached with revoked token")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAuthExpiredToken(t *testing.T) {
|
||||
pub, priv := generateTestKey(t)
|
||||
database := openTestDB(t)
|
||||
|
||||
// Issue an already-expired token.
|
||||
tokenStr, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("IssueToken: %v", err)
|
||||
}
|
||||
|
||||
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be reached with expired token")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("status = %d, want 401", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleGranted(t *testing.T) {
|
||||
claims := &token.Claims{Roles: []string{"admin"}}
|
||||
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
||||
|
||||
reached := false
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
reached = true
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rr.Code)
|
||||
}
|
||||
if !reached {
|
||||
t.Error("handler not reached with correct role")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleForbidden(t *testing.T) {
|
||||
claims := &token.Claims{Roles: []string{"reader"}}
|
||||
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
||||
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be reached without admin role")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want 403", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleNoClaims(t *testing.T) {
|
||||
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be reached without claims in context")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Errorf("status = %d, want 403", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitAllows(t *testing.T) {
|
||||
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
|
||||
// First 5 requests should be allowed (burst=5).
|
||||
for i := range 5 {
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("request %d: status = %d, want 200", i+1, rr.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitBlocks(t *testing.T) {
|
||||
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:9999"
|
||||
|
||||
// Exhaust the burst of 2.
|
||||
for range 2 {
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
}
|
||||
|
||||
// Next request should be rate-limited.
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
wantErr bool
|
||||
want string
|
||||
}{
|
||||
{"valid", "Bearer mytoken123", false, "mytoken123"},
|
||||
{"missing header", "", true, ""},
|
||||
{"no bearer prefix", "Token mytoken123", true, ""},
|
||||
{"empty token", "Bearer ", true, ""},
|
||||
{"case insensitive", "bearer mytoken123", false, "mytoken123"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tc.header != "" {
|
||||
req.Header.Set("Authorization", tc.header)
|
||||
}
|
||||
got, err := extractBearerToken(req)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err)
|
||||
}
|
||||
if !tc.wantErr && got != tc.want {
|
||||
t.Errorf("token = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user