New package providing the client side of the MCIAS SSO authorization code
flow. Web services use this to redirect users to MCIAS for login and to
exchange the returned authorization code for a JWT.

- Client type with AuthorizeURL() and ExchangeCode() (TLS 1.3 minimum)
- State cookie helpers (SameSite=Lax for cross-site redirect compatibility)
- Return-to cookie for preserving the original URL across the redirect
- RedirectToLogin() and HandleCallback() high-level helpers
- Full test suite with mock MCIAS server

Security:
- State is 256-bit random, stored in an HttpOnly/Secure/SameSite=Lax cookie
- Return-to URLs are stored client-side only (MCIAS never sees them)
- Login/callback paths are excluded from return-to to prevent redirect loops

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
226 lines
5.7 KiB
Go
226 lines
5.7 KiB
Go
package sso
|
|
|
|
import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
)
|
|
|
|
func TestNewValidation(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg Config
|
|
wantErr bool
|
|
}{
|
|
{"valid", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, false},
|
|
{"missing url", Config{ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, true},
|
|
{"missing client_id", Config{MciasURL: "https://mcias.example.com", RedirectURI: "https://mcr.example.com/cb"}, true},
|
|
{"missing redirect_uri", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr"}, true},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
_, err := New(tc.cfg)
|
|
if tc.wantErr && err == nil {
|
|
t.Error("expected error, got nil")
|
|
}
|
|
if !tc.wantErr && err != nil {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthorizeURL(t *testing.T) {
|
|
c, err := New(Config{
|
|
MciasURL: "http://localhost:8443",
|
|
ClientID: "mcr",
|
|
RedirectURI: "https://mcr.example.com/sso/callback",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New: %v", err)
|
|
}
|
|
|
|
u := c.AuthorizeURL("test-state")
|
|
if u == "" {
|
|
t.Fatal("AuthorizeURL returned empty string")
|
|
}
|
|
|
|
// Should contain all required params.
|
|
for _, want := range []string{"client_id=mcr", "state=test-state", "redirect_uri="} {
|
|
if !contains(u, want) {
|
|
t.Errorf("URL %q missing %q", u, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestExchangeCode(t *testing.T) {
|
|
// Fake MCIAS server.
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v1/sso/token" {
|
|
http.Error(w, "not found", http.StatusNotFound)
|
|
return
|
|
}
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
var req struct {
|
|
Code string `json:"code"`
|
|
ClientID string `json:"client_id"`
|
|
RedirectURI string `json:"redirect_uri"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
http.Error(w, "bad request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if req.Code != "valid-code" {
|
|
http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
|
"token": "jwt-token-here",
|
|
"expires_at": "2026-03-30T23:00:00Z",
|
|
})
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c, err := New(Config{
|
|
MciasURL: srv.URL,
|
|
ClientID: "mcr",
|
|
RedirectURI: "https://mcr.example.com/cb",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("New: %v", err)
|
|
}
|
|
|
|
// Valid code.
|
|
token, _, err := c.ExchangeCode(t.Context(), "valid-code")
|
|
if err != nil {
|
|
t.Fatalf("ExchangeCode: %v", err)
|
|
}
|
|
if token != "jwt-token-here" {
|
|
t.Errorf("token = %q, want %q", token, "jwt-token-here")
|
|
}
|
|
|
|
// Invalid code.
|
|
_, _, err = c.ExchangeCode(t.Context(), "bad-code")
|
|
if err == nil {
|
|
t.Error("expected error for bad code")
|
|
}
|
|
}
|
|
|
|
func TestGenerateState(t *testing.T) {
|
|
s1, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState: %v", err)
|
|
}
|
|
if len(s1) != 64 { // 32 bytes = 64 hex chars
|
|
t.Errorf("state length = %d, want 64", len(s1))
|
|
}
|
|
|
|
s2, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState: %v", err)
|
|
}
|
|
if s1 == s2 {
|
|
t.Error("two states should differ")
|
|
}
|
|
}
|
|
|
|
func TestStateCookieRoundTrip(t *testing.T) {
|
|
state := "test-state-value"
|
|
rec := httptest.NewRecorder()
|
|
SetStateCookie(rec, "mcr", state)
|
|
|
|
// Simulate a request with the cookie.
|
|
req := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
if err := ValidateStateCookie(w, req, "mcr", state); err != nil {
|
|
t.Fatalf("ValidateStateCookie: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestStateCookieMismatch(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
SetStateCookie(rec, "mcr", "correct-state")
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/sso/callback?state=wrong-state", nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req.AddCookie(c)
|
|
}
|
|
|
|
w := httptest.NewRecorder()
|
|
if err := ValidateStateCookie(w, req, "mcr", "wrong-state"); err == nil {
|
|
t.Error("expected error for state mismatch")
|
|
}
|
|
}
|
|
|
|
func TestReturnToCookie(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, "/repositories/myrepo", nil)
|
|
SetReturnToCookie(rec, req, "mcr")
|
|
|
|
// Read back.
|
|
req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req2.AddCookie(c)
|
|
}
|
|
|
|
w2 := httptest.NewRecorder()
|
|
path := ConsumeReturnToCookie(w2, req2, "mcr")
|
|
if path != "/repositories/myrepo" {
|
|
t.Errorf("return-to = %q, want %q", path, "/repositories/myrepo")
|
|
}
|
|
}
|
|
|
|
func TestReturnToDefaultsToRoot(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
|
w := httptest.NewRecorder()
|
|
path := ConsumeReturnToCookie(w, req, "mcr")
|
|
if path != "/" {
|
|
t.Errorf("return-to = %q, want %q", path, "/")
|
|
}
|
|
}
|
|
|
|
func TestReturnToSkipsLoginPaths(t *testing.T) {
|
|
for _, p := range []string{"/login", "/sso/callback"} {
|
|
rec := httptest.NewRecorder()
|
|
req := httptest.NewRequest(http.MethodGet, p, nil)
|
|
SetReturnToCookie(rec, req, "mcr")
|
|
|
|
req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
|
for _, c := range rec.Result().Cookies() {
|
|
req2.AddCookie(c)
|
|
}
|
|
|
|
w2 := httptest.NewRecorder()
|
|
path := ConsumeReturnToCookie(w2, req2, "mcr")
|
|
if path != "/" {
|
|
t.Errorf("return-to for %s = %q, want %q", p, path, "/")
|
|
}
|
|
}
|
|
}
|
|
|
|
// contains reports whether sub occurs anywhere within s.
//
// An empty sub is contained in every string, including the empty string.
// The result is the same as the previous hand-written check; the work is
// now done by strings.Contains.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}
|
|
|
|
// containsStr reports whether sub occurs anywhere within s.
//
// It is a thin wrapper over strings.Contains and returns the same results as
// the manual byte-window scan it replaces: true for an empty sub, false when
// sub is longer than s.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}
|