package middleware

import (
	"bytes"
	"context"
	"crypto/ed25519"
	"crypto/rand"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"slices"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/db"
	"git.wntrmute.dev/kyle/mcias/internal/model"
	"git.wntrmute.dev/kyle/mcias/internal/token"
)

// generateTestKey returns a freshly generated Ed25519 key pair used to sign
// and verify tokens in these tests. It fails the test on error.
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
	t.Helper()
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate test key: %v", err)
	}
	return pub, priv
}

// openTestDB opens a ":memory:" database, applies migrations, and registers a
// cleanup that closes it when the test finishes.
func openTestDB(t *testing.T) *db.DB {
	t.Helper()
	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open test db: %v", err)
	}
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate test db: %v", err)
	}
	t.Cleanup(func() { _ = database.Close() })
	return database
}

// testIssuer is the issuer used both when issuing and when validating tokens.
const testIssuer = "https://auth.example.com"

// issueAndTrackToken creates a valid JWT and records it in the DB.
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
	t.Helper()
	tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
	if err != nil {
		t.Fatalf("IssueToken: %v", err)
	}
	if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil {
		t.Fatalf("TrackToken: %v", err)
	}
	return tokenStr
}

// TestRequestLogger checks that RequestLogger passes the request through,
// emits log output, and never writes the Authorization header value to the log.
func TestRequestLogger(t *testing.T) {
	var buf bytes.Buffer
	logger := slog.New(slog.NewTextHandler(&buf, nil))
	handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", rr.Code)
	}
	if buf.Len() == 0 {
		t.Error("expected log output, got empty string")
	}

	// Security: Authorization header must not appear in logs.
	req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
	req2.Header.Set("Authorization", "Bearer secret-token-value")
	buf.Reset()
	rr2 := httptest.NewRecorder()
	handler.ServeHTTP(rr2, req2)

	// If nothing was logged, the leak check below would pass vacuously.
	if buf.Len() == 0 {
		t.Fatal("expected log output for request with Authorization header, got empty string")
	}
	if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
		t.Error("log output contains Authorization token value — credential leak!")
	}
}

// TestRequireAuthValid checks that a valid, tracked token reaches the wrapped
// handler with its claims (including roles) available in the request context.
func TestRequireAuthValid(t *testing.T) {
	pub, priv := generateTestKey(t)
	database := openTestDB(t)

	acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"})

	reached := false
	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		claims := ClaimsFromContext(r.Context())
		switch {
		case claims == nil:
			t.Error("claims not in context")
		case !slices.Contains(claims.Roles, "reader"):
			t.Errorf("claims roles = %v, want to contain %q", claims.Roles, "reader")
		}
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String())
	}
	if !reached {
		t.Error("handler was not reached with valid token")
	}
}

// TestRequireAuthMissingHeader checks that a request without an Authorization
// header is rejected with 401 and never reaches the wrapped handler.
func TestRequireAuthMissingHeader(t *testing.T) {
	pub, _ := generateTestKey(t)
	database := openTestDB(t)

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be reached without auth")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireAuthInvalidToken checks that a malformed bearer token is rejected
// with 401.
func TestRequireAuthInvalidToken(t *testing.T) {
	pub, _ := generateTestKey(t)
	database := openTestDB(t)

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be reached with invalid token")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer not.a.valid.jwt")
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireAuthRevokedToken checks that a correctly signed token whose JTI
// has been revoked in the DB is rejected with 401.
func TestRequireAuthRevokedToken(t *testing.T) {
	pub, priv := generateTestKey(t)
	database := openTestDB(t)

	acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}
	tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil)

	// Extract JTI and revoke the token.
	claims, err := token.ValidateToken(pub, tokenStr, testIssuer)
	if err != nil {
		t.Fatalf("ValidateToken: %v", err)
	}
	if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil {
		t.Fatalf("RevokeToken: %v", err)
	}

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be reached with revoked token")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireAuthExpiredToken checks that an expired token is rejected with
// 401 even though it is otherwise well-formed and tracked in the DB.
func TestRequireAuthExpiredToken(t *testing.T) {
	pub, priv := generateTestKey(t)
	database := openTestDB(t)

	acct, err := database.CreateAccount("carol", model.AccountTypeHuman, "hash")
	if err != nil {
		t.Fatalf("CreateAccount: %v", err)
	}

	// Issue an already-expired token and track it, so that the only reason
	// for rejection is expiry rather than an unknown (untracked) JTI.
	tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
	if err != nil {
		t.Fatalf("IssueToken: %v", err)
	}
	if err := database.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
		t.Fatalf("TrackToken: %v", err)
	}

	handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be reached with expired token")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
	req.Header.Set("Authorization", "Bearer "+tokenStr)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusUnauthorized {
		t.Errorf("status = %d, want 401", rr.Code)
	}
}

// TestRequireRoleGranted checks that claims carrying the required role let the
// request through to the wrapped handler.
func TestRequireRoleGranted(t *testing.T) {
	claims := &token.Claims{Roles: []string{"admin"}}
	ctx := context.WithValue(context.Background(), claimsKey, claims)

	reached := false
	handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusOK {
		t.Errorf("status = %d, want 200", rr.Code)
	}
	if !reached {
		t.Error("handler not reached with correct role")
	}
}

// TestRequireRoleForbidden checks that claims lacking the required role are
// rejected with 403.
func TestRequireRoleForbidden(t *testing.T) {
	claims := &token.Claims{Roles: []string{"reader"}}
	ctx := context.WithValue(context.Background(), claimsKey, claims)

	handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be reached without admin role")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusForbidden {
		t.Errorf("status = %d, want 403", rr.Code)
	}
}

// TestRequireRoleNoClaims checks that a request with no claims in its context
// is rejected with 403.
func TestRequireRoleNoClaims(t *testing.T) {
	handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		t.Error("handler should not be reached without claims in context")
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)

	if rr.Code != http.StatusForbidden {
		t.Errorf("status = %d, want 403", rr.Code)
	}
}

// TestRateLimitAllows checks that requests from one client within the burst
// size are all served.
func TestRateLimitAllows(t *testing.T) {
	handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "127.0.0.1:12345"

	// First 5 requests should be allowed (burst=5).
	for i := range 5 {
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Errorf("request %d: status = %d, want 200", i+1, rr.Code)
		}
	}
}

// TestRateLimitBlocks checks that, once a client has used up its burst, the
// next request is rejected with 429.
func TestRateLimitBlocks(t *testing.T) {
	handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
	req.RemoteAddr = "10.0.0.1:9999"

	// Exhaust the burst of 2. Both requests must succeed; otherwise a limiter
	// that rejects everything would also satisfy the 429 assertion below.
	for i := range 2 {
		rr := httptest.NewRecorder()
		handler.ServeHTTP(rr, req)
		if rr.Code != http.StatusOK {
			t.Fatalf("burst request %d: status = %d, want 200", i+1, rr.Code)
		}
	}

	// Next request should be rate-limited.
	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	if rr.Code != http.StatusTooManyRequests {
		t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
	}
}

// TestExtractBearerToken covers parsing of the Authorization header, including
// a missing header, a non-Bearer scheme, an empty token, and a lowercase scheme.
func TestExtractBearerToken(t *testing.T) {
	tests := []struct {
		name    string
		header  string
		wantErr bool
		want    string
	}{
		{"valid", "Bearer mytoken123", false, "mytoken123"},
		{"missing header", "", true, ""},
		{"no bearer prefix", "Token mytoken123", true, ""},
		{"empty token", "Bearer ", true, ""},
		{"case insensitive", "bearer mytoken123", false, "mytoken123"},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/", nil)
			if tc.header != "" {
				req.Header.Set("Authorization", tc.header)
			}

			got, err := extractBearerToken(req)
			if tc.wantErr {
				if err == nil {
					t.Errorf("wantErr=true, got nil error and token %q", got)
				}
				return
			}
			// An unexpected error makes the token comparison meaningless.
			if err != nil {
				t.Fatalf("wantErr=false, got err=%v", err)
			}
			if got != tc.want {
				t.Errorf("token = %q, want %q", got, tc.want)
			}
		})
	}
}