package auth

import (
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"testing"
	"time"
)

// newTestServer starts an httptest.Server that routes the MCIAS endpoints
// exercised by Client. Either handler may be nil, in which case that route is
// not registered and requests to it get the mux's default 404. The server is
// closed automatically via t.Cleanup when the test finishes.
func newTestServer(t *testing.T, loginHandler, validateHandler http.HandlerFunc) *httptest.Server {
	t.Helper()
	mux := http.NewServeMux()
	if loginHandler != nil {
		mux.HandleFunc("/v1/auth/login", loginHandler)
	}
	if validateHandler != nil {
		mux.HandleFunc("/v1/token/validate", validateHandler)
	}
	srv := httptest.NewServer(mux)
	t.Cleanup(srv.Close)
	return srv
}

// newTestClient builds a Client pointed at serverURL, failing the test
// immediately if construction returns an error.
func newTestClient(t *testing.T, serverURL string) *Client {
	t.Helper()
	c, err := NewClient(serverURL, "", "mcr-test", []string{"env:test"})
	if err != nil {
		t.Fatalf("NewClient: %v", err)
	}
	return c
}

// writeTestJSON sets a JSON Content-Type and encodes v as the response body.
// It is called from the server's handler goroutine, so an encoding failure
// is reported with t.Errorf (safe from any goroutine) instead of being
// silently discarded; t.Fatalf must not be used off the test goroutine.
func writeTestJSON(t *testing.T, w http.ResponseWriter, v any) {
	t.Helper()
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(v); err != nil {
		t.Errorf("encoding test response: %v", err)
	}
}

// validTestResponse returns a validateResponse that marks the token valid and
// carries the given claims. The Claims fields are assigned one by one so the
// anonymous struct type (and its JSON tags) is not restated at every call
// site, where any drift from the real type would break compilation.
func validTestResponse(subject, accountType string, roles []string) validateResponse {
	var resp validateResponse
	resp.Valid = true
	resp.Claims.Subject = subject
	resp.Claims.AccountType = accountType
	resp.Claims.Roles = roles
	return resp
}

func TestLoginSuccess(t *testing.T) {
	srv := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		// Require a decodable login body so a malformed client request
		// surfaces as a Login error rather than a silent success.
		var req loginRequest
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		writeTestJSON(t, w, loginResponse{
			Token:     "tok-abc",
			ExpiresIn: 3600,
		})
	}, nil)

	c := newTestClient(t, srv.URL)
	token, expiresIn, err := c.Login("alice", "secret")
	if err != nil {
		t.Fatalf("Login: %v", err)
	}
	if token != "tok-abc" {
		t.Fatalf("token: got %q, want %q", token, "tok-abc")
	}
	if expiresIn != 3600 {
		t.Fatalf("expiresIn: got %d, want %d", expiresIn, 3600)
	}
}

func TestLoginFailure(t *testing.T) {
	srv := newTestServer(t, func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}, nil)

	c := newTestClient(t, srv.URL)
	_, _, err := c.Login("alice", "wrong")
	if !errors.Is(err, ErrUnauthorized) {
		t.Fatalf("Login error: got %v, want %v", err, ErrUnauthorized)
	}
}

func TestValidateSuccess(t *testing.T) {
	srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
		writeTestJSON(t, w, validTestResponse("alice", "user", []string{"reader", "writer"}))
	})

	c := newTestClient(t, srv.URL)
	claims, err := c.ValidateToken("valid-token-123")
	if err != nil {
		t.Fatalf("ValidateToken: %v", err)
	}
	if claims.Subject != "alice" {
		t.Fatalf("subject: got %q, want %q", claims.Subject, "alice")
	}
	if claims.AccountType != "user" {
		t.Fatalf("account_type: got %q, want %q", claims.AccountType, "user")
	}
	if len(claims.Roles) != 2 || claims.Roles[0] != "reader" || claims.Roles[1] != "writer" {
		t.Fatalf("roles: got %v, want [reader writer]", claims.Roles)
	}
}

func TestValidateRevoked(t *testing.T) {
	srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
		writeTestJSON(t, w, validateResponse{Valid: false})
	})

	c := newTestClient(t, srv.URL)
	_, err := c.ValidateToken("revoked-token")
	if !errors.Is(err, ErrUnauthorized) {
		t.Fatalf("ValidateToken error: got %v, want %v", err, ErrUnauthorized)
	}
}

func TestValidateCacheHit(t *testing.T) {
	var callCount atomic.Int64
	srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
		callCount.Add(1)
		writeTestJSON(t, w, validTestResponse("bob", "service", []string{"admin"}))
	})

	c := newTestClient(t, srv.URL)

	// First call — should hit the server.
	claims1, err := c.ValidateToken("cached-token")
	if err != nil {
		t.Fatalf("first ValidateToken: %v", err)
	}
	if callCount.Load() != 1 {
		t.Fatalf("expected 1 server call after first validate, got %d", callCount.Load())
	}

	// Second call — should come from cache.
	claims2, err := c.ValidateToken("cached-token")
	if err != nil {
		t.Fatalf("second ValidateToken: %v", err)
	}
	if callCount.Load() != 1 {
		t.Fatalf("expected 1 server call after second validate (cache hit), got %d", callCount.Load())
	}

	if claims1.Subject != claims2.Subject {
		t.Fatalf("cached claims mismatch: %q vs %q", claims1.Subject, claims2.Subject)
	}
}

func TestValidateCacheExpiry(t *testing.T) {
	var callCount atomic.Int64
	srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
		callCount.Add(1)
		writeTestJSON(t, w, validTestResponse("charlie", "user", nil))
	})

	c := newTestClient(t, srv.URL)

	// Inject a controllable clock. All ValidateToken calls below run on the
	// test goroutine, so swapping the function between calls is race-free.
	now := time.Now()
	c.cache.now = func() time.Time { return now }

	// First call.
	if _, err := c.ValidateToken("expiry-token"); err != nil {
		t.Fatalf("first ValidateToken: %v", err)
	}
	if callCount.Load() != 1 {
		t.Fatalf("expected 1 server call, got %d", callCount.Load())
	}

	// Second call within TTL — cache hit.
	if _, err := c.ValidateToken("expiry-token"); err != nil {
		t.Fatalf("second ValidateToken: %v", err)
	}
	if callCount.Load() != 1 {
		t.Fatalf("expected 1 server call (cache hit), got %d", callCount.Load())
	}

	// Advance clock past the cache TTL (assumed to be 30s by this test).
	c.cache.now = func() time.Time { return now.Add(31 * time.Second) }

	// Third call — cache miss, should hit server again.
	if _, err := c.ValidateToken("expiry-token"); err != nil {
		t.Fatalf("third ValidateToken: %v", err)
	}
	if callCount.Load() != 2 {
		t.Fatalf("expected 2 server calls after cache expiry, got %d", callCount.Load())
	}
}