package server

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"git.wntrmute.dev/mc/mcr/internal/auth"
)

// fakeLoginClient is a test double for the login dependency of TokenHandler.
// Login ignores its username/password arguments and returns the configured
// token, expiresIn, and err verbatim.
type fakeLoginClient struct {
	token     string
	expiresIn int
	err       error
}

func (f *fakeLoginClient) Login(_, _ string) (string, int, error) {
	return f.token, f.expiresIn, f.err
}

// fakeTokenValidator is a test double for the token-validation dependency of
// TokenHandler. ValidateToken ignores its argument and returns the configured
// claims and err verbatim.
type fakeTokenValidator struct {
	claims *auth.Claims
	err    error
}

func (f *fakeTokenValidator) ValidateToken(_ string) (*auth.Claims, error) {
	return f.claims, f.err
}

// decodeTokenHandlerResponse requires a 200 OK status on rec and decodes its
// body as a tokenResponse, failing the calling test on either check.
func decodeTokenHandlerResponse(t *testing.T, rec *httptest.ResponseRecorder) tokenResponse {
	t.Helper()

	if rec.Code != http.StatusOK {
		t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK)
	}

	var resp tokenResponse
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode response: %v", err)
	}
	return resp
}

// requireTokenHandlerUnauthorized requires that rec carries a 401 status and
// an OCI error body containing exactly one error whose code is UNAUTHORIZED.
func requireTokenHandlerUnauthorized(t *testing.T, rec *httptest.ResponseRecorder) {
	t.Helper()

	if rec.Code != http.StatusUnauthorized {
		t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
	}

	var ociErr ociErrorResponse
	if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil {
		t.Fatalf("decode OCI error: %v", err)
	}
	if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" {
		t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors)
	}
}

// TestTokenHandlerSuccess covers valid username/password credentials: the
// validator rejects the password as a token, the login client succeeds, and
// the response must echo the issued token, its lifetime, and an RFC 3339
// issue time.
func TestTokenHandlerSuccess(t *testing.T) {
	lc := &fakeLoginClient{token: "tok-xyz", expiresIn: 7200}
	tv := &fakeTokenValidator{err: auth.ErrUnauthorized}
	handler := TokenHandler(lc, tv)

	req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
	req.SetBasicAuth("alice", "secret")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	resp := decodeTokenHandlerResponse(t, rec)
	if resp.Token != "tok-xyz" {
		t.Fatalf("token: got %q, want %q", resp.Token, "tok-xyz")
	}
	if resp.ExpiresIn != 7200 {
		t.Fatalf("expires_in: got %d, want %d", resp.ExpiresIn, 7200)
	}
	if resp.IssuedAt == "" {
		t.Fatal("issued_at: expected non-empty RFC 3339 timestamp")
	}
	// Non-empty alone does not prove the format; time.Parse with RFC3339 also
	// accepts an optional fractional-seconds field.
	if _, err := time.Parse(time.RFC3339, resp.IssuedAt); err != nil {
		t.Fatalf("issued_at: %q is not an RFC 3339 timestamp: %v", resp.IssuedAt, err)
	}
}

// TestTokenHandlerJWTAsPassword covers a client that supplies an existing
// token as the Basic-auth password: the validator accepts it, login would
// fail, and the handler must return that same token unchanged.
func TestTokenHandlerJWTAsPassword(t *testing.T) {
	lc := &fakeLoginClient{err: auth.ErrUnauthorized}
	tv := &fakeTokenValidator{claims: &auth.Claims{
		Subject:     "mcp-agent",
		AccountType: "system",
		Roles:       nil,
	}}
	handler := TokenHandler(lc, tv)

	jwt := "eyJhbGciOiJFZERTQSJ9.eyJzdWIiOiJ0ZXN0In0.c2lnbmF0dXJl"
	req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
	req.SetBasicAuth("x", jwt)
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	resp := decodeTokenHandlerResponse(t, rec)
	if resp.Token != jwt {
		t.Fatalf("token: got %q, want JWT pass-through", resp.Token)
	}
}

// TestTokenHandlerJWTFallsBackToLogin covers a password shaped like a JWT
// (three dot-separated parts) that the validator rejects: the handler must
// fall through to login and return the login token.
func TestTokenHandlerJWTFallsBackToLogin(t *testing.T) {
	lc := &fakeLoginClient{token: "login-tok", expiresIn: 3600}
	tv := &fakeTokenValidator{err: auth.ErrUnauthorized}
	handler := TokenHandler(lc, tv)

	// Password looks like a JWT but validator rejects it — should fall through to login.
	req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
	req.SetBasicAuth("alice", "not.a.jwt")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	resp := decodeTokenHandlerResponse(t, rec)
	if resp.Token != "login-tok" {
		t.Fatalf("token: got %q, want %q (login fallback)", resp.Token, "login-tok")
	}
}

// TestTokenHandlerInvalidCreds covers credentials rejected by both the
// validator and the login client: the handler must answer 401 with an OCI
// UNAUTHORIZED error body.
func TestTokenHandlerInvalidCreds(t *testing.T) {
	lc := &fakeLoginClient{err: auth.ErrUnauthorized}
	tv := &fakeTokenValidator{err: auth.ErrUnauthorized}
	handler := TokenHandler(lc, tv)

	req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
	req.SetBasicAuth("alice", "wrong")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	requireTokenHandlerUnauthorized(t, rec)
}

// TestTokenHandlerMissingAuth covers a request with no Authorization header:
// even though the login client would hand out a token, the handler must
// answer 401 with an OCI UNAUTHORIZED error body.
func TestTokenHandlerMissingAuth(t *testing.T) {
	lc := &fakeLoginClient{token: "should-not-matter"}
	tv := &fakeTokenValidator{err: auth.ErrUnauthorized}
	handler := TokenHandler(lc, tv)

	req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
	// No Authorization header.
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	requireTokenHandlerUnauthorized(t, rec)
}