SetReturnToCookie stored /sso/redirect as the return-to path, causing a redirect loop after successful SSO login: the callback would redirect back to /sso/redirect instead of /. Filter all /sso/* paths, not just /sso/callback. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
226 lines
5.7 KiB
Go
226 lines
5.7 KiB
Go
package sso
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)
|
|
|
|
func TestNewValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg Config
|
|
wantErr bool
|
|
}{
|
|
{"valid", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, false},
|
|
{"missing url", Config{ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, true},
|
|
{"missing client_id", Config{MciasURL: "https://mcias.example.com", RedirectURI: "https://mcr.example.com/cb"}, true},
|
|
{"missing redirect_uri", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr"}, true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := New(tc.cfg)
|
|
if tc.wantErr && err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
if !tc.wantErr && err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthorizeURL(t *testing.T) {
|
|
c, err := New(Config{
|
|
MciasURL: "http://localhost:8443",
|
|
ClientID: "mcr",
|
|
RedirectURI: "https://mcr.example.com/sso/callback",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New: %v", err)
|
|
}
|
|
|
|
u := c.AuthorizeURL("test-state")
|
|
if u == "" {
|
|
t.Fatal("AuthorizeURL returned empty string")
|
|
}
|
|
|
|
// Should contain all required params.
|
|
for _, want := range []string{"client_id=mcr", "state=test-state", "redirect_uri="} {
|
|
if !contains(u, want) {
|
|
t.Errorf("URL %q missing %q", u, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExchangeCode(t *testing.T) {
|
|
// Fake MCIAS server.
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/sso/token" {
|
|
http.Error(w, "not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
Code string `json:"code"`
|
|
ClientID string `json:"client_id"`
|
|
RedirectURI string `json:"redirect_uri"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "bad request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.Code != "valid-code" {
|
|
http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
|
"token": "jwt-token-here",
|
|
"expires_at": "2026-03-30T23:00:00Z",
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c, err := New(Config{
|
|
MciasURL: srv.URL,
|
|
ClientID: "mcr",
|
|
RedirectURI: "https://mcr.example.com/cb",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New: %v", err)
|
|
}
|
|
|
|
// Valid code.
|
|
token, _, err := c.ExchangeCode(t.Context(), "valid-code")
|
|
if err != nil {
|
|
t.Fatalf("ExchangeCode: %v", err)
|
|
}
|
|
if token != "jwt-token-here" {
|
|
t.Errorf("token = %q, want %q", token, "jwt-token-here")
|
|
}
|
|
|
|
// Invalid code.
|
|
_, _, err = c.ExchangeCode(t.Context(), "bad-code")
|
|
if err == nil {
|
|
t.Error("expected error for bad code")
|
|
}
|
|
}
|
|
|
|
func TestGenerateState(t *testing.T) {
|
|
s1, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState: %v", err)
|
|
}
|
|
if len(s1) != 64 { // 32 bytes = 64 hex chars
|
|
t.Errorf("state length = %d, want 64", len(s1))
|
|
}
|
|
|
|
s2, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState: %v", err)
|
|
}
|
|
if s1 == s2 {
|
|
t.Error("two states should differ")
|
|
}
|
|
}
|
|
|
|
func TestStateCookieRoundTrip(t *testing.T) {
|
|
state := "test-state-value"
|
|
rec := httptest.NewRecorder()
|
|
SetStateCookie(rec, "mcr", state)
|
|
|
|
// Simulate a request with the cookie.
|
|
req := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
if err := ValidateStateCookie(w, req, "mcr", state); err != nil {
|
|
t.Fatalf("ValidateStateCookie: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStateCookieMismatch(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
SetStateCookie(rec, "mcr", "correct-state")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/sso/callback?state=wrong-state", nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
if err := ValidateStateCookie(w, req, "mcr", "wrong-state"); err == nil {
|
|
t.Error("expected error for state mismatch")
|
|
}
|
|
}
|
|
|
|
func TestReturnToCookie(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/repositories/myrepo", nil)
|
|
SetReturnToCookie(rec, req, "mcr")
|
|
|
|
// Read back.
|
|
req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req2.AddCookie(c)
|
|
}
|
|
|
|
w2 := httptest.NewRecorder()
|
|
path := ConsumeReturnToCookie(w2, req2, "mcr")
|
|
if path != "/repositories/myrepo" {
|
|
t.Errorf("return-to = %q, want %q", path, "/repositories/myrepo")
|
|
}
|
|
}
|
|
|
|
func TestReturnToDefaultsToRoot(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
|
w := httptest.NewRecorder()
|
|
path := ConsumeReturnToCookie(w, req, "mcr")
|
|
if path != "/" {
|
|
t.Errorf("return-to = %q, want %q", path, "/")
|
|
}
|
|
}
|
|
|
|
func TestReturnToSkipsLoginPaths(t *testing.T) {
|
|
for _, p := range []string{"/login", "/sso/callback", "/sso/redirect"} {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, p, nil)
|
|
SetReturnToCookie(rec, req, "mcr")
|
|
|
|
req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req2.AddCookie(c)
|
|
}
|
|
|
|
w2 := httptest.NewRecorder()
|
|
path := ConsumeReturnToCookie(w2, req2, "mcr")
|
|
if path != "/" {
|
|
t.Errorf("return-to for %s = %q, want %q", p, path, "/")
|
|
}
|
|
}
|
|
}
|
|
|
|
// contains reports whether sub occurs anywhere within s. An empty sub is
// contained in every string. It delegates to strings.Contains instead of
// hand-rolling the search.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}
|
|
|
|
// containsStr reports whether sub occurs anywhere within s. An empty sub
// matches every string. It is kept for callers that still reference it and
// delegates to strings.Contains instead of a naive O(len(s)*len(sub)) scan.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}
|