Files
mcdsl/sso/sso_test.go
Kyle Isom bcab16f2bf Fix SSO return-to redirect loop
SetReturnToCookie stored /sso/redirect as the return-to path,
causing a redirect loop after successful SSO login: the callback
would redirect back to /sso/redirect instead of /. Filter all
/sso/* paths, not just /sso/callback.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 14:54:55 -07:00

226 lines
5.7 KiB
Go

package sso
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)
// TestNewValidation checks that New accepts a Config with MciasURL,
// ClientID, and RedirectURI all set. It also checks that New returns an
// error when any one of them is left empty.
func TestNewValidation(t *testing.T) {
	const (
		mciasURL    = "https://mcias.example.com"
		clientID    = "mcr"
		redirectURI = "https://mcr.example.com/cb"
	)

	cases := []struct {
		name    string
		cfg     Config
		wantErr bool
	}{
		{name: "valid", cfg: Config{MciasURL: mciasURL, ClientID: clientID, RedirectURI: redirectURI}},
		{name: "missing url", cfg: Config{ClientID: clientID, RedirectURI: redirectURI}, wantErr: true},
		{name: "missing client_id", cfg: Config{MciasURL: mciasURL, RedirectURI: redirectURI}, wantErr: true},
		{name: "missing redirect_uri", cfg: Config{MciasURL: mciasURL, ClientID: clientID}, wantErr: true},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			_, err := New(c.cfg)
			switch {
			case c.wantErr && err == nil:
				t.Error("expected error, got nil")
			case !c.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
// TestAuthorizeURL checks that AuthorizeURL returns a non-empty URL
// containing the client ID, the caller-supplied state, and a redirect_uri
// parameter.
func TestAuthorizeURL(t *testing.T) {
	client, err := New(Config{
		MciasURL:    "http://localhost:8443",
		ClientID:    "mcr",
		RedirectURI: "https://mcr.example.com/sso/callback",
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	authURL := client.AuthorizeURL("test-state")
	if authURL == "" {
		t.Fatal("AuthorizeURL returned empty string")
	}

	// Only substring presence is checked. Parameter order and the encoding
	// of redirect_uri's value are left to AuthorizeURL.
	wantParts := []string{"client_id=mcr", "state=test-state", "redirect_uri="}
	for i := range wantParts {
		if part := wantParts[i]; !contains(authURL, part) {
			t.Errorf("URL %q missing %q", authURL, part)
		}
	}
}
// TestExchangeCode runs ExchangeCode against an in-process stand-in for the
// MCIAS token endpoint. The stand-in answers only POST /v1/sso/token. It
// issues a fixed token for the code "valid-code" and responds 401 to any
// other code. The test checks both outcomes.
func TestExchangeCode(t *testing.T) {
	// Fake MCIAS server.
	fakeMCIAS := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch {
		case r.URL.Path != "/v1/sso/token":
			http.Error(w, "not found", http.StatusNotFound)
			return
		case r.Method != http.MethodPost:
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}

		// The fields mirror the token request body. Only Code influences
		// the response.
		var tokenReq struct {
			Code        string `json:"code"`
			ClientID    string `json:"client_id"`
			RedirectURI string `json:"redirect_uri"`
		}
		if err := json.NewDecoder(r.Body).Decode(&tokenReq); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if tokenReq.Code != "valid-code" {
			http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		// The handler has no channel for reporting an encode error. A
		// failed write would most likely show up as an ExchangeCode error
		// in the test below.
		_ = json.NewEncoder(w).Encode(map[string]string{
			"token":      "jwt-token-here",
			"expires_at": "2026-03-30T23:00:00Z",
		})
	})
	srv := httptest.NewServer(fakeMCIAS)
	defer srv.Close()

	client, err := New(Config{
		MciasURL:    srv.URL,
		ClientID:    "mcr",
		RedirectURI: "https://mcr.example.com/cb",
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// Valid code.
	const wantToken = "jwt-token-here"
	tok, _, err := client.ExchangeCode(t.Context(), "valid-code")
	if err != nil {
		t.Fatalf("ExchangeCode: %v", err)
	}
	if tok != wantToken {
		t.Errorf("token = %q, want %q", tok, wantToken)
	}

	// Invalid code.
	if _, _, err := client.ExchangeCode(t.Context(), "bad-code"); err == nil {
		t.Error("expected error for bad code")
	}
}
// TestGenerateState checks that GenerateState returns a 64-character value
// and that two consecutive calls do not return the same value.
func TestGenerateState(t *testing.T) {
	// next calls GenerateState and aborts the test on error.
	next := func() string {
		t.Helper()
		s, err := GenerateState()
		if err != nil {
			t.Fatalf("GenerateState: %v", err)
		}
		return s
	}

	first := next()
	// 32 bytes = 64 hex chars
	if n := len(first); n != 64 {
		t.Errorf("state length = %d, want 64", n)
	}

	if second := next(); first == second {
		t.Error("two states should differ")
	}
}
// TestStateCookieRoundTrip checks that ValidateStateCookie accepts a request
// carrying the cookie written by SetStateCookie when the same state value
// is supplied.
func TestStateCookieRoundTrip(t *testing.T) {
	const state = "test-state-value"

	issued := httptest.NewRecorder()
	SetStateCookie(issued, "mcr", state)

	// Build the callback request, replaying every cookie the recorder
	// captured from SetStateCookie.
	callback := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil)
	cookies := issued.Result().Cookies()
	for i := range cookies {
		callback.AddCookie(cookies[i])
	}

	if err := ValidateStateCookie(httptest.NewRecorder(), callback, "mcr", state); err != nil {
		t.Fatalf("ValidateStateCookie: %v", err)
	}
}
// TestStateCookieMismatch checks that ValidateStateCookie rejects a request
// whose state value differs from the one SetStateCookie stored.
func TestStateCookieMismatch(t *testing.T) {
	issued := httptest.NewRecorder()
	SetStateCookie(issued, "mcr", "correct-state")

	callback := httptest.NewRequest(http.MethodGet, "/sso/callback?state=wrong-state", nil)
	cookies := issued.Result().Cookies()
	for i := range cookies {
		callback.AddCookie(cookies[i])
	}

	err := ValidateStateCookie(httptest.NewRecorder(), callback, "mcr", "wrong-state")
	if err == nil {
		t.Error("expected error for state mismatch")
	}
}
// TestReturnToCookie checks that the path recorded by SetReturnToCookie is
// returned by ConsumeReturnToCookie on a later request carrying the cookie.
func TestReturnToCookie(t *testing.T) {
	const want = "/repositories/myrepo"

	issued := httptest.NewRecorder()
	SetReturnToCookie(issued, httptest.NewRequest(http.MethodGet, want, nil), "mcr")

	// Read the cookie back on the callback request.
	callback := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
	cookies := issued.Result().Cookies()
	for i := range cookies {
		callback.AddCookie(cookies[i])
	}

	if got := ConsumeReturnToCookie(httptest.NewRecorder(), callback, "mcr"); got != want {
		t.Errorf("return-to = %q, want %q", got, want)
	}
}
// TestReturnToDefaultsToRoot checks that ConsumeReturnToCookie returns "/"
// when the request carries no return-to cookie.
func TestReturnToDefaultsToRoot(t *testing.T) {
	const wantPath = "/"

	noCookie := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
	got := ConsumeReturnToCookie(httptest.NewRecorder(), noCookie, "mcr")
	if got != wantPath {
		t.Errorf("return-to = %q, want %q", got, wantPath)
	}
}
// TestReturnToSkipsLoginPaths checks that SetReturnToCookie never records a
// login-flow path as the post-login destination. In each case
// ConsumeReturnToCookie must fall back to "/".
//
// Storing /sso/redirect previously caused a redirect loop after a
// successful SSO login. Every /sso/* path is meant to be filtered, not only
// the known endpoints, so the list also includes an arbitrary /sso/* path.
func TestReturnToSkipsLoginPaths(t *testing.T) {
	paths := []string{
		"/login",
		"/sso/callback",
		"/sso/redirect",
		// Not a known endpoint; covers the "all /sso/* paths" rule.
		"/sso/logout",
	}
	for _, p := range paths {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest(http.MethodGet, p, nil)
		SetReturnToCookie(rec, req, "mcr")

		req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
		for _, c := range rec.Result().Cookies() {
			req2.AddCookie(c)
		}

		w2 := httptest.NewRecorder()
		path := ConsumeReturnToCookie(w2, req2, "mcr")
		if path != "/" {
			t.Errorf("return-to for %s = %q, want %q", p, path, "/")
		}
	}
}
// contains reports whether sub is a substring of s. An empty sub is
// contained in every string, including the empty string.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}

// containsStr reports whether sub occurs anywhere in s. It is equivalent to
// contains and is kept so that any existing callers continue to compile.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}