package middleware

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"log/slog"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/db"
	"git.wntrmute.dev/kyle/mcias/internal/model"
	"git.wntrmute.dev/kyle/mcias/internal/token"
)

// generateTestKey returns a freshly generated Ed25519 key pair, failing the
// test immediately if key generation errors.
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
	t.Helper()
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate test key: %v", err)
	}
	return pub, priv
}

// openTestDB opens an in-memory database, applies migrations, and registers
// a cleanup that closes it when the test finishes.
func openTestDB(t *testing.T) *db.DB {
	t.Helper()
	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open test db: %v", err)
	}
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate test db: %v", err)
	}
	// Close errors are irrelevant once the test is over.
	t.Cleanup(func() { _ = database.Close() })
	return database
}

// testIssuer is the issuer used both when minting and when validating tokens.
const testIssuer = "https://auth.example.com"

// issueAndTrackToken issues a JWT for subject "user-uuid" with the given roles
// and a one-hour lifetime, then records its JTI against accountID in the DB.
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
	t.Helper()
	tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
	if err != nil {
		t.Fatalf("IssueToken: %v", err)
	}
	if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil {
		t.Fatalf("TrackToken: %v", err)
	}
	return tokenStr
}

// TestRequestLogger checks that RequestLogger passes requests through to the
// wrapped handler, produces log output, and never logs the bearer token.
func TestRequestLogger(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))

	handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", rr.Code)
	}

	logOutput := buf.String()
	if logOutput == "" {
		t.Error("expected log output, got empty string")
	}

	// Security: Authorization header must not appear in logs.
	req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
	req2.Header.Set("Authorization", "Bearer secret-token-value")
	buf.Reset()
	rr2 := httptest.NewRecorder()
	handler.ServeHTTP(rr2, req2)

	if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
		t.Error("log output contains Authorization token value — credential leak!")
	}
}

// TestRequireAuthValid checks that a valid, tracked token reaches the handler
// and that claims are available from the request context.
func TestRequireAuthValid(t *testing.T) {
	pub, priv := generateTestKey(t)
	database := openTestDB(t)

	acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"})

	reached := false
	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		claims := ClaimsFromContext(r.Context())
		if claims == nil {
			t.Error("claims not in context")
		}
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String())
	}
	if !reached {
		t.Error("handler was not reached with valid token")
	}
}

// TestRequireAuthMissingHeader checks that a request without an Authorization
// header is rejected with 401 before reaching the handler.
func TestRequireAuthMissingHeader(t *testing.T) {
	pub, _ := generateTestKey(t)
	database := openTestDB(t)

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached without auth")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireAuthInvalidToken checks that a malformed bearer token is
// rejected with 401 before reaching the handler.
func TestRequireAuthInvalidToken(t *testing.T) {
	pub, _ := generateTestKey(t)
	database := openTestDB(t)

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached with invalid token")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer not.a.valid.jwt")
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireAuthRevokedToken checks that a correctly signed, tracked token is
// rejected with 401 after it has been revoked in the DB.
func TestRequireAuthRevokedToken(t *testing.T) {
	pub, priv := generateTestKey(t)
	database := openTestDB(t)

	acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil)

	// Extract JTI and revoke the token.
	claims, err := token.ValidateToken(pub, tokenStr, testIssuer)
	if err != nil {
		t.Fatalf("ValidateToken: %v", err)
	}
	if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil {
		t.Fatalf("RevokeToken: %v", err)
	}

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached with revoked token")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireAuthExpiredToken checks that a token issued with a negative
// lifetime (already expired) is rejected with 401.
func TestRequireAuthExpiredToken(t *testing.T) {
	pub, priv := generateTestKey(t)
	database := openTestDB(t)

	// Issue an already-expired token.
	// NOTE(review): this token is never tracked in the DB, so depending on how
	// RequireAuth treats unknown JTIs the 401 may not be caused by expiry
	// alone — confirm the rejection reason if RequireAuth changes.
	tokenStr, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
	if err != nil {
		t.Fatalf("IssueToken: %v", err)
	}

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached with expired token")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireRoleGranted checks that claims carrying the required role let
// the request through.
func TestRequireRoleGranted(t *testing.T) {
	claims := &token.Claims{Roles: []string{"admin"}}
	ctx := context.WithValue(context.Background(), claimsKey, claims)

	reached := false
	handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", rr.Code)
	}
	if !reached {
		t.Error("handler not reached with correct role")
	}
}

// TestRequireRoleForbidden checks that claims lacking the required role are
// rejected with 403.
func TestRequireRoleForbidden(t *testing.T) {
	claims := &token.Claims{Roles: []string{"reader"}}
	ctx := context.WithValue(context.Background(), claimsKey, claims)

	handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached without admin role")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusForbidden {
		t.Errorf("status = %d, want 403", rr.Code)
	}
}

// TestRequireRoleNoClaims checks that a request with no claims in its context
// is rejected with 403.
func TestRequireRoleNoClaims(t *testing.T) {
	handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		t.Error("handler should not be reached without claims in context")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusForbidden {
		t.Errorf("status = %d, want 403", rr.Code)
	}
}

// TestRateLimitAllows checks that requests within the burst allowance pass.
func TestRateLimitAllows(t *testing.T) {
	handler := RateLimit(10, 5, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "127.0.0.1:12345"

	// First 5 requests should be allowed (burst=5).
	for i := range 5 {
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("request %d: status = %d, want 200", i+1, rr.Code)
		}
	}
}

// TestRateLimitBlocks checks that the request after an exhausted burst is
// rejected with 429.
func TestRateLimitBlocks(t *testing.T) {
	handler := RateLimit(0.1, 2, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "10.0.0.1:9999"

	// Exhaust the burst of 2. Asserting success here keeps the test from
	// passing vacuously if the limiter rejected every request.
	for i := range 2 {
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("burst request %d: status = %d, want 200", i+1, rr.Code)
		}
	}

	// Next request should be rate-limited.
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusTooManyRequests {
		t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
	}
}

// TestExtractBearerToken covers well-formed, missing, mis-prefixed, empty,
// and lower-case-scheme Authorization headers.
func TestExtractBearerToken(t *testing.T) {
	tests := []struct {
		name    string
		header  string
		want    string
		wantErr bool
	}{
		{"valid", "Bearer mytoken123", "mytoken123", false},
		{"missing header", "", "", true},
		{"no bearer prefix", "Token mytoken123", "", true},
		{"empty token", "Bearer ", "", true},
		{"case insensitive", "bearer mytoken123", "mytoken123", false},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.header != "" {
				req.Header.Set("Authorization", tc.header)
			}
			got, err := extractBearerToken(req)
			if (err != nil) != tc.wantErr {
				t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err)
			}
			if !tc.wantErr && got != tc.want {
				t.Errorf("token = %q, want %q", got, tc.want)
			}
		})
	}
}

// TestClientIP verifies the proxy-aware IP extraction logic.
func TestClientIP(t *testing.T) {
	proxy := net.ParseIP("10.0.0.1")

	tests := []struct {
		name          string
		remoteAddr    string
		xForwardedFor string
		xRealIP       string
		trustedProxy  net.IP
		want          string
	}{
		{
			name:       "no proxy configured: uses RemoteAddr",
			remoteAddr: "203.0.113.5:54321",
			want:       "203.0.113.5",
		},
		{
			name:          "proxy configured but request not from proxy: uses RemoteAddr",
			remoteAddr:    "198.51.100.9:12345",
			xForwardedFor: "203.0.113.99",
			trustedProxy:  proxy,
			want:          "198.51.100.9",
		},
		{
			name:         "request from trusted proxy with X-Real-IP: uses X-Real-IP",
			remoteAddr:   "10.0.0.1:8080",
			xRealIP:      "203.0.113.42",
			trustedProxy: proxy,
			want:         "203.0.113.42",
		},
		{
			name:          "request from trusted proxy with X-Forwarded-For: uses first entry",
			remoteAddr:    "10.0.0.1:8080",
			xForwardedFor: "203.0.113.77, 10.0.0.2",
			trustedProxy:  proxy,
			want:          "203.0.113.77",
		},
		{
			name:          "X-Real-IP takes precedence over X-Forwarded-For",
			remoteAddr:    "10.0.0.1:8080",
			xRealIP:       "203.0.113.11",
			xForwardedFor: "203.0.113.22",
			trustedProxy:  proxy,
			want:          "203.0.113.11",
		},
		{
			name:          "proxy request with invalid X-Real-IP falls back to X-Forwarded-For",
			remoteAddr:    "10.0.0.1:8080",
			xRealIP:       "not-an-ip",
			xForwardedFor: "203.0.113.55",
			trustedProxy:  proxy,
			want:          "203.0.113.55",
		},
		{
			name:         "proxy request with no forwarding headers falls back to RemoteAddr host",
			remoteAddr:   "10.0.0.1:8080",
			trustedProxy: proxy,
			want:         "10.0.0.1",
		},
		{
			// Security: attacker fakes X-Forwarded-For but connects directly.
			name:          "spoofed X-Forwarded-For from non-proxy IP is ignored",
			remoteAddr:    "198.51.100.99:9999",
			xForwardedFor: "127.0.0.1",
			trustedProxy:  proxy,
			want:          "198.51.100.99",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			req.RemoteAddr = tc.remoteAddr
			if tc.xForwardedFor != "" {
				req.Header.Set("X-Forwarded-For", tc.xForwardedFor)
			}
			if tc.xRealIP != "" {
				req.Header.Set("X-Real-IP", tc.xRealIP)
			}
			got := ClientIP(req, tc.trustedProxy)
			if got != tc.want {
				t.Errorf("ClientIP = %q, want %q", got, tc.want)
			}
		})
	}
}

// TestRateLimitTrustedProxy verifies that rate limiting uses the forwarded IP
// when the request originates from a trusted proxy.
func TestRateLimitTrustedProxy(t *testing.T) {
	proxy := net.ParseIP("10.0.0.1")
	// Very low rps and burst=1 so any two requests from the same IP are blocked.
	handler := RateLimit(0.001, 1, proxy)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// Two requests from the same real client IP, forwarded by the proxy.
	// Both carry the same X-Real-IP; the second should be rate-limited.
	for i, wantStatus := range []int{http.StatusOK, http.StatusTooManyRequests} {
		req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
		req.RemoteAddr = "10.0.0.1:5000" // from the trusted proxy
		req.Header.Set("X-Real-IP", "203.0.113.5")
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != wantStatus {
			t.Errorf("request %d: status = %d, want %d", i+1, rr.Code, wantStatus)
		}
	}

	// A different real client (different X-Real-IP) should still be allowed.
	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "10.0.0.1:5001"
	req.Header.Set("X-Real-IP", "203.0.113.99")
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Errorf("distinct client: status = %d, want 200 (separate bucket)", rr.Code)
	}
}