package sso

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)

// TestNewValidation checks that New accepts a Config with MciasURL, ClientID
// and RedirectURI all set, and returns an error when any one of them is
// omitted.
func TestNewValidation(t *testing.T) {
	tests := []struct {
		name    string
		cfg     Config
		wantErr bool
	}{
		{"valid", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, false},
		{"missing url", Config{ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, true},
		{"missing client_id", Config{MciasURL: "https://mcias.example.com", RedirectURI: "https://mcr.example.com/cb"}, true},
		{"missing redirect_uri", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr"}, true},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := New(tc.cfg)
			if tc.wantErr && err == nil {
				t.Error("expected error, got nil")
			}
			if !tc.wantErr && err != nil {
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}

// TestAuthorizeURL checks that AuthorizeURL returns a non-empty URL that
// mentions the client ID, the supplied state, and a redirect_uri parameter.
func TestAuthorizeURL(t *testing.T) {
	c, err := New(Config{
		MciasURL:    "http://localhost:8443",
		ClientID:    "mcr",
		RedirectURI: "https://mcr.example.com/sso/callback",
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	u := c.AuthorizeURL("test-state")
	if u == "" {
		t.Fatal("AuthorizeURL returned empty string")
	}

	// Should contain all required params.
	for _, want := range []string{"client_id=mcr", "state=test-state", "redirect_uri="} {
		if !contains(u, want) {
			t.Errorf("URL %q missing %q", u, want)
		}
	}
}

// TestExchangeCode runs ExchangeCode against a fake token endpoint and checks
// that a valid code yields the served token while an unknown code yields an
// error.
func TestExchangeCode(t *testing.T) {
	// Fake MCIAS server.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/sso/token" {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}
		if r.Method != http.MethodPost {
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		var req struct {
			Code        string `json:"code"`
			ClientID    string `json:"client_id"`
			RedirectURI string `json:"redirect_uri"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if req.Code != "valid-code" {
			http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		// Report (rather than discard) an encode failure so a broken fake
		// server is not mistaken for a client bug. Errorf, unlike Fatalf, may
		// be called from the handler goroutine.
		if err := json.NewEncoder(w).Encode(map[string]string{
			"token":      "jwt-token-here",
			"expires_at": "2026-03-30T23:00:00Z",
		}); err != nil {
			t.Errorf("fake server: encode token response: %v", err)
		}
	}))
	defer srv.Close()

	c, err := New(Config{
		MciasURL:    srv.URL,
		ClientID:    "mcr",
		RedirectURI: "https://mcr.example.com/cb",
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// Valid code.
	token, _, err := c.ExchangeCode(t.Context(), "valid-code")
	if err != nil {
		t.Fatalf("ExchangeCode: %v", err)
	}
	if token != "jwt-token-here" {
		t.Errorf("token = %q, want %q", token, "jwt-token-here")
	}

	// Invalid code.
	_, _, err = c.ExchangeCode(t.Context(), "bad-code")
	if err == nil {
		t.Error("expected error for bad code")
	}
}

// TestGenerateState checks that GenerateState returns a 64-character string
// and that two consecutive calls return different values.
func TestGenerateState(t *testing.T) {
	s1, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState: %v", err)
	}
	if len(s1) != 64 { // 32 bytes = 64 hex chars
		t.Errorf("state length = %d, want 64", len(s1))
	}

	s2, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState: %v", err)
	}
	if s1 == s2 {
		t.Error("two states should differ")
	}
}

// TestStateCookieRoundTrip sets a state cookie with SetStateCookie, replays
// the recorded cookies on a callback request, and expects ValidateStateCookie
// to accept the same state value.
func TestStateCookieRoundTrip(t *testing.T) {
	state := "test-state-value"
	rec := httptest.NewRecorder()
	SetStateCookie(rec, "mcr", state)

	// Simulate a request with the cookie.
	req := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil)
	for _, c := range rec.Result().Cookies() {
		req.AddCookie(c)
	}

	w := httptest.NewRecorder()
	if err := ValidateStateCookie(w, req, "mcr", state); err != nil {
		t.Fatalf("ValidateStateCookie: %v", err)
	}
}

// TestStateCookieMismatch sets a state cookie for one value and expects
// ValidateStateCookie to return an error when asked to validate a different
// value.
func TestStateCookieMismatch(t *testing.T) {
	rec := httptest.NewRecorder()
	SetStateCookie(rec, "mcr", "correct-state")

	req := httptest.NewRequest(http.MethodGet, "/sso/callback?state=wrong-state", nil)
	for _, c := range rec.Result().Cookies() {
		req.AddCookie(c)
	}

	w := httptest.NewRecorder()
	if err := ValidateStateCookie(w, req, "mcr", "wrong-state"); err == nil {
		t.Error("expected error for state mismatch")
	}
}

// TestReturnToCookie calls SetReturnToCookie for a request to
// /repositories/myrepo and expects ConsumeReturnToCookie to hand that path
// back when the recorded cookies are replayed.
func TestReturnToCookie(t *testing.T) {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/repositories/myrepo", nil)
	SetReturnToCookie(rec, req, "mcr")

	// Read back.
	req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
	for _, c := range rec.Result().Cookies() {
		req2.AddCookie(c)
	}

	w2 := httptest.NewRecorder()
	path := ConsumeReturnToCookie(w2, req2, "mcr")
	if path != "/repositories/myrepo" {
		t.Errorf("return-to = %q, want %q", path, "/repositories/myrepo")
	}
}

// TestReturnToDefaultsToRoot expects ConsumeReturnToCookie to return "/" when
// the request carries no cookies.
func TestReturnToDefaultsToRoot(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
	w := httptest.NewRecorder()
	path := ConsumeReturnToCookie(w, req, "mcr")
	if path != "/" {
		t.Errorf("return-to = %q, want %q", path, "/")
	}
}

// TestReturnToSkipsLoginPaths expects the return-to path to come back as "/"
// when SetReturnToCookie was called for a request to /login, /sso/callback or
// /sso/redirect.
func TestReturnToSkipsLoginPaths(t *testing.T) {
	for _, p := range []string{"/login", "/sso/callback", "/sso/redirect"} {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, p, nil)
		SetReturnToCookie(rec, req, "mcr")

		req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
		for _, c := range rec.Result().Cookies() {
			req2.AddCookie(c)
		}

		w2 := httptest.NewRecorder()
		path := ConsumeReturnToCookie(w2, req2, "mcr")
		if path != "/" {
			t.Errorf("return-to for %s = %q, want %q", p, path, "/")
		}
	}
}

// contains reports whether sub occurs within s. It delegates to
// strings.Contains, which gives the same answer as the previous hand-written
// scan for every input, including an empty s or sub.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}

// containsStr reports whether sub occurs within s. It is kept under its
// original name in case other test files in the package call it; new code
// should use strings.Contains directly.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}