Files
mcias/internal/middleware/middleware_test.go
Kyle Isom ec7c966ad2 trusted proxy, TOTP replay protection, new tests
- Trusted proxy config option for proxy-aware IP extraction
  used by rate limiting and audit logs; validates proxy IP
  before trusting X-Forwarded-For / X-Real-IP headers
- TOTP replay protection via counter-based validation to
  reject reused codes within the same time step (±30s)
- RateLimit middleware updated to extract client IP from
  proxy headers without IP spoofing risk
- New tests for ClientIP proxy logic (spoofed headers,
  fallback) and extended rate-limit proxy coverage
- HTMX error banner script integrated into web UI base
- .gitignore updated for mciasdb build artifact

Security: resolves CRIT-01 (TOTP replay attack) and
DEF-03 (proxy-unaware rate limiting); gRPC TOTP
enrollment aligned with REST via StorePendingTOTP

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-12 17:44:01 -07:00

465 lines
13 KiB
Go

package middleware
import (
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/internal/token"
)
// generateTestKey returns a fresh Ed25519 key pair drawn from crypto/rand.
// Key-generation failure aborts the calling test.
func generateTestKey(t *testing.T) (pub ed25519.PublicKey, priv ed25519.PrivateKey) {
	t.Helper()
	var err error
	if pub, priv, err = ed25519.GenerateKey(rand.Reader); err != nil {
		t.Fatalf("generate test key: %v", err)
	}
	return pub, priv
}
// openTestDB opens an in-memory database and applies the schema migrations
// via db.Migrate. Any failure aborts the calling test.
//
// The close is registered with t.Cleanup right after a successful Open, so
// the handle is released even when migration fails and t.Fatalf ends the
// test early.
func openTestDB(t *testing.T) *db.DB {
	t.Helper()
	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open test db: %v", err)
	}
	t.Cleanup(func() {
		// The test is over by now; a close error has nowhere useful to go.
		_ = database.Close()
	})
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate test db: %v", err)
	}
	return database
}
const (
	// testIssuer is the issuer passed to token.IssueToken and to RequireAuth /
	// token.ValidateToken throughout these tests.
	testIssuer = "https://auth.example.com"
)

// issueAndTrackToken signs a JWT with priv for the fixed identifier
// "user-uuid", with the given roles and a one-hour lifetime. It then records
// the token's JTI, issue time, and expiry against accountID via
// database.TrackToken and returns the signed token string. Any failure aborts
// the calling test.
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
	t.Helper()
	signed, issued, issueErr := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
	if issueErr != nil {
		t.Fatalf("IssueToken: %v", issueErr)
	}
	trackErr := database.TrackToken(issued.JTI, accountID, issued.IssuedAt, issued.ExpiresAt)
	if trackErr != nil {
		t.Fatalf("TrackToken: %v", trackErr)
	}
	return signed
}
// TestRequestLogger checks three things about RequestLogger:
//   - it passes requests through to the wrapped handler;
//   - it writes log output;
//   - it never writes the bearer token from an Authorization header into
//     the log.
func TestRequestLogger(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", rr.Code)
	}
	if buf.Len() == 0 {
		t.Error("expected log output, got empty string")
	}

	// Security: Authorization header must not appear in logs.
	req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
	req2.Header.Set("Authorization", "Bearer secret-token-value")
	buf.Reset()
	rr2 := httptest.NewRecorder()
	handler.ServeHTTP(rr2, req2)
	if rr2.Code != http.StatusOK {
		t.Errorf("status = %d, want 200 (request with Authorization header)", rr2.Code)
	}
	// Without this check the leak assertion below would pass trivially if the
	// request carrying credentials were not logged at all.
	if buf.Len() == 0 {
		t.Error("expected log output for request with Authorization header, got empty string")
	}
	if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
		t.Error("log output contains Authorization token value — credential leak!")
	}
}
// TestRequireAuthValid verifies the happy path: a correctly signed token that
// has been recorded with TrackToken reaches the wrapped handler, and claims
// are available from the request context.
func TestRequireAuthValid(t *testing.T) {
	pub, priv := generateTestKey(t)
	store := openTestDB(t)
	account, err := store.CreateAccount("alice", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	bearer := "Bearer " + issueAndTrackToken(t, priv, store, account.ID, []string{"reader"})

	var reached bool
	next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		if ClaimsFromContext(r.Context()) == nil {
			t.Error("claims not in context")
		}
		w.WriteHeader(http.StatusOK)
	})
	protected := RequireAuth(pub, store, testIssuer)(next)

	request := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	request.Header.Set("Authorization", bearer)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusOK {
		t.Errorf("status = %d, want 200; body: %s", got, recorder.Body.String())
	}
	if !reached {
		t.Error("handler was not reached with valid token")
	}
}
// TestRequireAuthMissingHeader verifies that a request without an
// Authorization header is rejected with 401 and never reaches the wrapped
// handler.
func TestRequireAuthMissingHeader(t *testing.T) {
	// Only the public key is needed: no token is issued in this test.
	pub, _ := generateTestKey(t)
	database := openTestDB(t)
	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached without auth")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}
// TestRequireAuthInvalidToken verifies that a malformed bearer token is
// rejected with 401 and the wrapped handler never runs.
func TestRequireAuthInvalidToken(t *testing.T) {
	publicKey, _ := generateTestKey(t)
	store := openTestDB(t)
	protected := RequireAuth(publicKey, store, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached with invalid token")
		w.WriteHeader(http.StatusOK)
	}))

	request := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	request.Header.Set("Authorization", "Bearer not.a.valid.jwt")
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", got)
	}
}
// TestRequireAuthRevokedToken verifies that a validly signed token is rejected
// with 401 once its JTI has been revoked via RevokeToken.
func TestRequireAuthRevokedToken(t *testing.T) {
	pub, priv := generateTestKey(t)
	store := openTestDB(t)
	account, err := store.CreateAccount("bob", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	signed := issueAndTrackToken(t, priv, store, account.ID, nil)

	// Recover the JTI from the signed token so its tracked record can be revoked.
	parsed, err := token.ValidateToken(pub, signed, testIssuer)
	if err != nil {
		t.Fatalf("ValidateToken: %v", err)
	}
	if err = store.RevokeToken(parsed.JTI, "test revocation"); err != nil {
		t.Fatalf("RevokeToken: %v", err)
	}

	protected := RequireAuth(pub, store, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached with revoked token")
		w.WriteHeader(http.StatusOK)
	}))
	request := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	request.Header.Set("Authorization", "Bearer "+signed)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", got)
	}
}
// TestRequireAuthExpiredToken verifies that a token issued with a negative
// lifetime (i.e. already expired) is rejected with 401.
//
// NOTE(review): this token is never passed to TrackToken. Depending on the
// order of RequireAuth's checks, a 401 might come from the missing tracking
// record rather than from expiry. Confirm that expiry is what is exercised.
func TestRequireAuthExpiredToken(t *testing.T) {
	pub, priv := generateTestKey(t)
	store := openTestDB(t)

	// A lifetime of -1 minute yields an already-expired token.
	stale, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
	if err != nil {
		t.Fatalf("IssueToken: %v", err)
	}

	protected := RequireAuth(pub, store, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached with expired token")
		w.WriteHeader(http.StatusOK)
	}))
	request := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	request.Header.Set("Authorization", "Bearer "+stale)
	recorder := httptest.NewRecorder()
	protected.ServeHTTP(recorder, request)

	if got := recorder.Code; got != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", got)
	}
}
func TestRequireRoleGranted(t *testing.T) {
claims := &token.Claims{Roles: []string{"admin"}}
ctx := context.WithValue(context.Background(), claimsKey, claims)
reached := false
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
reached = true
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("status = %d, want 200", rr.Code)
}
if !reached {
t.Error("handler not reached with correct role")
}
}
func TestRequireRoleForbidden(t *testing.T) {
claims := &token.Claims{Roles: []string{"reader"}}
ctx := context.WithValue(context.Background(), claimsKey, claims)
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
t.Error("handler should not be reached without admin role")
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusForbidden {
t.Errorf("status = %d, want 403", rr.Code)
}
}
// TestRequireRoleNoClaims verifies that RequireRole answers 403 when the
// request context carries no claims at all.
func TestRequireRoleNoClaims(t *testing.T) {
	guarded := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached without claims in context")
		w.WriteHeader(http.StatusOK)
	}))

	recorder := httptest.NewRecorder()
	guarded.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := recorder.Code; got != http.StatusForbidden {
		t.Errorf("status = %d, want 403", got)
	}
}
// TestRateLimitAllows verifies that requests from one client up to the burst
// size are all admitted.
func TestRateLimitAllows(t *testing.T) {
	const burst = 5
	limited := RateLimit(10, burst, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "127.0.0.1:12345"

	// Every request within the burst budget must succeed.
	for n := 1; n <= burst; n++ {
		recorder := httptest.NewRecorder()
		limited.ServeHTTP(recorder, req)
		if got := recorder.Code; got != http.StatusOK {
			t.Errorf("request %d: status = %d, want 200", n, got)
		}
	}
}
// TestRateLimitBlocks verifies two things:
//   - requests up to the burst size are admitted;
//   - the next request from the same client is rejected with 429.
func TestRateLimitBlocks(t *testing.T) {
	handler := RateLimit(0.1, 2, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "10.0.0.1:9999"

	// Exhaust the burst of 2. These must be admitted; otherwise the 429
	// check below would also pass for a limiter that rejects everything.
	for i := range 2 {
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("burst request %d: status = %d, want 200", i+1, rr.Code)
		}
	}

	// Next request should be rate-limited.
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusTooManyRequests {
		t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
	}
}
// TestExtractBearerToken table-tests extractBearerToken against a valid
// header, a missing header, a wrong scheme, an empty token, and a lower-case
// scheme.
func TestExtractBearerToken(t *testing.T) {
	cases := []struct {
		name    string
		header  string
		want    string
		wantErr bool
	}{
		{name: "valid", header: "Bearer mytoken123", want: "mytoken123"},
		{name: "missing header", wantErr: true},
		{name: "no bearer prefix", header: "Token mytoken123", wantErr: true},
		{name: "empty token", header: "Bearer ", wantErr: true},
		{name: "case insensitive", header: "bearer mytoken123", want: "mytoken123"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.header != "" {
				req.Header.Set("Authorization", tc.header)
			}

			got, err := extractBearerToken(req)
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err)
			}
			if tc.wantErr {
				// The returned token is meaningless when an error is expected.
				return
			}
			if got != tc.want {
				t.Errorf("token = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestClientIP verifies the proxy-aware IP extraction logic. Forwarding
// headers (X-Real-IP, then X-Forwarded-For) are honoured only when the
// connection's RemoteAddr matches the configured trusted proxy. In every
// other case ClientIP falls back to the RemoteAddr host.
func TestClientIP(t *testing.T) {
	trusted := net.ParseIP("10.0.0.1")
	cases := []struct {
		name          string
		remoteAddr    string
		xForwardedFor string
		xRealIP       string
		trustedProxy  net.IP
		want          string
	}{
		{
			name:       "no proxy configured: uses RemoteAddr",
			remoteAddr: "203.0.113.5:54321",
			want:       "203.0.113.5",
		},
		{
			name:          "proxy configured but request not from proxy: uses RemoteAddr",
			remoteAddr:    "198.51.100.9:12345",
			xForwardedFor: "203.0.113.99",
			trustedProxy:  trusted,
			want:          "198.51.100.9",
		},
		{
			name:         "request from trusted proxy with X-Real-IP: uses X-Real-IP",
			remoteAddr:   "10.0.0.1:8080",
			xRealIP:      "203.0.113.42",
			trustedProxy: trusted,
			want:         "203.0.113.42",
		},
		{
			name:          "request from trusted proxy with X-Forwarded-For: uses first entry",
			remoteAddr:    "10.0.0.1:8080",
			xForwardedFor: "203.0.113.77, 10.0.0.2",
			trustedProxy:  trusted,
			want:          "203.0.113.77",
		},
		{
			name:          "X-Real-IP takes precedence over X-Forwarded-For",
			remoteAddr:    "10.0.0.1:8080",
			xRealIP:       "203.0.113.11",
			xForwardedFor: "203.0.113.22",
			trustedProxy:  trusted,
			want:          "203.0.113.11",
		},
		{
			name:          "proxy request with invalid X-Real-IP falls back to X-Forwarded-For",
			remoteAddr:    "10.0.0.1:8080",
			xRealIP:       "not-an-ip",
			xForwardedFor: "203.0.113.55",
			trustedProxy:  trusted,
			want:          "203.0.113.55",
		},
		{
			name:         "proxy request with no forwarding headers falls back to RemoteAddr host",
			remoteAddr:   "10.0.0.1:8080",
			trustedProxy: trusted,
			want:         "10.0.0.1",
		},
		{
			// Security: attacker fakes X-Forwarded-For but connects directly.
			name:          "spoofed X-Forwarded-For from non-proxy IP is ignored",
			remoteAddr:    "198.51.100.99:9999",
			xForwardedFor: "127.0.0.1",
			trustedProxy:  trusted,
			want:          "198.51.100.99",
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tc.remoteAddr
			// Only set headers the case actually specifies; an empty value
			// means "header absent".
			for header, value := range map[string]string{
				"X-Forwarded-For": tc.xForwardedFor,
				"X-Real-IP":       tc.xRealIP,
			} {
				if value != "" {
					req.Header.Set(header, value)
				}
			}
			if got := ClientIP(req, tc.trustedProxy); got != tc.want {
				t.Errorf("ClientIP = %q, want %q", got, tc.want)
			}
		})
	}
}
// TestRateLimitTrustedProxy verifies that, for requests arriving from the
// trusted proxy, rate limiting is keyed on the forwarded client IP
// (X-Real-IP) rather than on the proxy's own address.
func TestRateLimitTrustedProxy(t *testing.T) {
	trusted := net.ParseIP("10.0.0.1")
	// The rate is tiny and burst is 1, so a second request from the same
	// client cannot be admitted within the test's lifetime.
	limited := RateLimit(0.001, 1, trusted)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// send issues a login request arriving from remoteAddr that carries
	// realIP in X-Real-IP, and returns the response status.
	send := func(remoteAddr, realIP string) int {
		req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
		req.RemoteAddr = remoteAddr
		req.Header.Set("X-Real-IP", realIP)
		recorder := httptest.NewRecorder()
		limited.ServeHTTP(recorder, req)
		return recorder.Code
	}

	// The same forwarded client twice: the first is admitted, the second throttled.
	expected := []int{http.StatusOK, http.StatusTooManyRequests}
	for i, want := range expected {
		if got := send("10.0.0.1:5000", "203.0.113.5"); got != want {
			t.Errorf("request %d: status = %d, want %d", i+1, got, want)
		}
	}

	// A different forwarded client lands in its own bucket.
	if got := send("10.0.0.1:5001", "203.0.113.99"); got != http.StatusOK {
		t.Errorf("distinct client: status = %d, want 200 (separate bucket)", got)
	}
}