Files
mcdsl/sso/sso_test.go
Kyle Isom 8561b34451 Add mcdsl/sso package for SSO redirect clients
New package providing the client side of the MCIAS SSO authorization
code flow. Web services use this to redirect users to MCIAS for login
and exchange the returned authorization code for a JWT.

- Client type with AuthorizeURL() and ExchangeCode() (TLS 1.3 minimum)
- State cookie helpers (SameSite=Lax for cross-site redirect compat)
- Return-to cookie for preserving the original URL across the redirect
- RedirectToLogin() and HandleCallback() high-level helpers
- Full test suite with mock MCIAS server

Security:
- State is 256-bit random, stored in HttpOnly/Secure/Lax cookie
- Return-to URLs stored client-side only (MCIAS never sees them)
- Login/callback paths excluded from return-to to prevent loops

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-30 15:26:50 -07:00

226 lines
5.7 KiB
Go

package sso
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
)
// TestNewValidation checks that New accepts a Config with MciasURL, ClientID,
// and RedirectURI all set, and returns an error when any one of them is empty.
func TestNewValidation(t *testing.T) {
	const (
		mciasURL    = "https://mcias.example.com"
		redirectURI = "https://mcr.example.com/cb"
	)

	cases := []struct {
		name    string
		cfg     Config
		wantErr bool
	}{
		{name: "valid", cfg: Config{MciasURL: mciasURL, ClientID: "mcr", RedirectURI: redirectURI}, wantErr: false},
		{name: "missing url", cfg: Config{ClientID: "mcr", RedirectURI: redirectURI}, wantErr: true},
		{name: "missing client_id", cfg: Config{MciasURL: mciasURL, RedirectURI: redirectURI}, wantErr: true},
		{name: "missing redirect_uri", cfg: Config{MciasURL: mciasURL, ClientID: "mcr"}, wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			_, err := New(tc.cfg)
			switch {
			case tc.wantErr && err == nil:
				t.Error("expected error, got nil")
			case !tc.wantErr && err != nil:
				t.Errorf("unexpected error: %v", err)
			}
		})
	}
}
// TestAuthorizeURL checks that AuthorizeURL points at the configured MCIAS
// host and carries the client ID, state, and redirect URI as query
// parameters with exactly the expected values.
//
// The URL is parsed rather than substring-matched: a substring check such as
// "redirect_uri=" would also accept an empty or wrong value, and
// "client_id=mcr" would match a longer client ID such as "mcrx".
func TestAuthorizeURL(t *testing.T) {
	const redirectURI = "https://mcr.example.com/sso/callback"

	c, err := New(Config{
		MciasURL:    "http://localhost:8443",
		ClientID:    "mcr",
		RedirectURI: redirectURI,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	raw := c.AuthorizeURL("test-state")
	if raw == "" {
		t.Fatal("AuthorizeURL returned empty string")
	}

	u, err := url.Parse(raw)
	if err != nil {
		t.Fatalf("AuthorizeURL returned unparseable URL %q: %v", raw, err)
	}
	if u.Host != "localhost:8443" {
		t.Errorf("URL %q host = %q, want %q", raw, u.Host, "localhost:8443")
	}

	// Query() decodes percent-escapes, so redirect_uri is compared in its
	// original (unescaped) form regardless of how AuthorizeURL encoded it.
	q := u.Query()
	wantParams := []struct{ key, value string }{
		{"client_id", "mcr"},
		{"state", "test-state"},
		{"redirect_uri", redirectURI},
	}
	for _, p := range wantParams {
		if got := q.Get(p.key); got != p.value {
			t.Errorf("URL %q: %s = %q, want %q", raw, p.key, got, p.value)
		}
	}
}
// TestExchangeCode exercises ExchangeCode against a fake MCIAS token endpoint.
//
// The fake accepts only POST /v1/sso/token with a JSON body of code,
// client_id, and redirect_uri. It requires client_id and redirect_uri to
// match the client's Config, returns a token for "valid-code", and answers
// any other code with 401. The test expects the token back for the valid
// code and an error for the invalid one.
func TestExchangeCode(t *testing.T) {
	const (
		clientID    = "mcr"
		redirectURI = "https://mcr.example.com/cb"
		wantToken   = "jwt-token-here"
	)

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path != "/v1/sso/token" {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}
		if r.Method != http.MethodPost {
			http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
			return
		}
		var req struct {
			Code        string `json:"code"`
			ClientID    string `json:"client_id"`
			RedirectURI string `json:"redirect_uri"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		// Previously these fields were decoded but never checked, so a
		// client sending the wrong client_id/redirect_uri (or none) passed.
		// t.Errorf, unlike t.Fatalf, may be called from this server goroutine.
		if req.ClientID != clientID || req.RedirectURI != redirectURI {
			t.Errorf("token request client_id=%q redirect_uri=%q, want %q and %q",
				req.ClientID, req.RedirectURI, clientID, redirectURI)
			http.Error(w, "bad request", http.StatusBadRequest)
			return
		}
		if req.Code != "valid-code" {
			http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized)
			return
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]string{
			"token":      wantToken,
			"expires_at": "2026-03-30T23:00:00Z",
		}); err != nil {
			t.Errorf("encoding token response: %v", err)
		}
	}))
	defer srv.Close()

	c, err := New(Config{
		MciasURL:    srv.URL,
		ClientID:    clientID,
		RedirectURI: redirectURI,
	})
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	// Valid code: the token from the fake server is returned.
	token, _, err := c.ExchangeCode(t.Context(), "valid-code")
	if err != nil {
		t.Fatalf("ExchangeCode: %v", err)
	}
	if token != wantToken {
		t.Errorf("token = %q, want %q", token, wantToken)
	}

	// Invalid code: the 401 from the fake server surfaces as an error.
	if _, _, err := c.ExchangeCode(t.Context(), "bad-code"); err == nil {
		t.Error("expected error for bad code")
	}
}
// TestGenerateState checks that GenerateState returns a 64-character value
// (32 bytes, hex-encoded) and that two consecutive calls do not collide.
func TestGenerateState(t *testing.T) {
	first, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState: %v", err)
	}
	// 32 bytes hex-encode to 64 characters.
	if n := len(first); n != 64 {
		t.Errorf("state length = %d, want 64", n)
	}

	second, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState: %v", err)
	}
	if second == first {
		t.Error("two states should differ")
	}
}
// TestStateCookieRoundTrip sets a state cookie, replays it on a callback
// request that carries the same state, and expects validation to succeed.
func TestStateCookieRoundTrip(t *testing.T) {
	const state = "test-state-value"

	setRec := httptest.NewRecorder()
	SetStateCookie(setRec, "mcr", state)

	// Build the callback request a browser would send after the redirect.
	callback := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil)
	cookies := setRec.Result().Cookies()
	for i := range cookies {
		callback.AddCookie(cookies[i])
	}

	if err := ValidateStateCookie(httptest.NewRecorder(), callback, "mcr", state); err != nil {
		t.Fatalf("ValidateStateCookie: %v", err)
	}
}
// TestStateCookieMismatch stores one state in the cookie, presents a
// different state on the callback, and expects validation to fail.
func TestStateCookieMismatch(t *testing.T) {
	issued := httptest.NewRecorder()
	SetStateCookie(issued, "mcr", "correct-state")

	const presented = "wrong-state"
	req := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+presented, nil)
	for _, ck := range issued.Result().Cookies() {
		req.AddCookie(ck)
	}

	err := ValidateStateCookie(httptest.NewRecorder(), req, "mcr", presented)
	if err == nil {
		t.Error("expected error for state mismatch")
	}
}
// TestReturnToCookie records the originally requested path in a return-to
// cookie and checks that ConsumeReturnToCookie hands it back on the callback.
func TestReturnToCookie(t *testing.T) {
	const want = "/repositories/myrepo"

	setRec := httptest.NewRecorder()
	SetReturnToCookie(setRec, httptest.NewRequest(http.MethodGet, want, nil), "mcr")

	// Replay the stored cookie on the SSO callback request.
	callback := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
	for _, ck := range setRec.Result().Cookies() {
		callback.AddCookie(ck)
	}

	if got := ConsumeReturnToCookie(httptest.NewRecorder(), callback, "mcr"); got != want {
		t.Errorf("return-to = %q, want %q", got, want)
	}
}
// TestReturnToDefaultsToRoot checks that ConsumeReturnToCookie falls back to
// "/" when the request carries no return-to cookie.
func TestReturnToDefaultsToRoot(t *testing.T) {
	const want = "/"

	req := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
	got := ConsumeReturnToCookie(httptest.NewRecorder(), req, "mcr")
	if got != want {
		t.Errorf("return-to = %q, want %q", got, want)
	}
}
// TestReturnToSkipsLoginPaths checks that a return-to cookie set while the
// user is on the login or SSO callback path does not send them back there.
// For both paths, the callback must fall back to "/".
func TestReturnToSkipsLoginPaths(t *testing.T) {
	// roundTrip sets a return-to cookie from a request for path, replays it on
	// the callback, and returns whatever ConsumeReturnToCookie resolves to.
	roundTrip := func(path string) string {
		setRec := httptest.NewRecorder()
		SetReturnToCookie(setRec, httptest.NewRequest(http.MethodGet, path, nil), "mcr")

		callback := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
		for _, ck := range setRec.Result().Cookies() {
			callback.AddCookie(ck)
		}
		return ConsumeReturnToCookie(httptest.NewRecorder(), callback, "mcr")
	}

	for _, p := range []string{"/login", "/sso/callback"} {
		if got := roundTrip(p); got != "/" {
			t.Errorf("return-to for %s = %q, want %q", p, got, "/")
		}
	}
}
// contains reports whether sub occurs anywhere within s. An empty sub is
// contained in every string, including "".
//
// It delegates to strings.Contains. The previous hand-rolled substring search
// computed the same result. The helper is kept in case other test files in
// the package call it.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}

// containsStr reports whether sub occurs anywhere within s. It is equivalent
// to contains and is retained only for existing callers; new code should use
// strings.Contains directly.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}