Add mcdsl/sso package for SSO redirect clients
New package providing the client side of the MCIAS SSO authorization code flow. Web services use this to redirect users to MCIAS for login and exchange the returned authorization code for a JWT. - Client type with AuthorizeURL() and ExchangeCode() (TLS 1.3 minimum) - State cookie helpers (SameSite=Lax for cross-site redirect compat) - Return-to cookie for preserving the original URL across the redirect - RedirectToLogin() and HandleCallback() high-level helpers - Full test suite with mock MCIAS server Security: - State is 256-bit random, stored in HttpOnly/Secure/Lax cookie - Return-to URLs stored client-side only (MCIAS never sees them) - Login/callback paths excluded from return-to to prevent loops Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
304
sso/sso.go
Normal file
304
sso/sso.go
Normal file
@@ -0,0 +1,304 @@
|
||||
// Package sso provides an SSO redirect client for Metacircular web services.
|
||||
//
|
||||
// Services redirect unauthenticated users to MCIAS for login. After
|
||||
// authentication, MCIAS redirects back with an authorization code that
|
||||
// the service exchanges for a JWT token. This package handles the
|
||||
// redirect, state management, and code exchange.
|
||||
//
|
||||
// Security design:
|
||||
// - State cookies use SameSite=Lax (not Strict) because the redirect from
|
||||
// MCIAS back to the service is a cross-site navigation.
|
||||
// - State is a 256-bit random value stored in an HttpOnly cookie.
|
||||
// - Return-to URLs are stored in a separate cookie so MCIAS never sees them.
|
||||
// - The code exchange is a server-to-server HTTPS call (TLS 1.3 minimum).
|
||||
package sso
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
stateBytes = 32 // 256 bits
|
||||
stateCookieAge = 5 * 60 // 5 minutes in seconds
|
||||
)
|
||||
|
||||
// Config holds the SSO client configuration. The values must match the
|
||||
// SSO client registration in MCIAS config.
|
||||
type Config struct {
|
||||
// MciasURL is the base URL of the MCIAS server.
|
||||
MciasURL string
|
||||
|
||||
// ClientID is the registered SSO client identifier.
|
||||
ClientID string
|
||||
|
||||
// RedirectURI is the callback URL that MCIAS redirects to after login.
|
||||
// Must exactly match the redirect_uri registered in MCIAS config.
|
||||
RedirectURI string
|
||||
|
||||
// CACert is an optional path to a PEM-encoded CA certificate for
|
||||
// verifying the MCIAS server's TLS certificate.
|
||||
CACert string
|
||||
}
|
||||
|
||||
// Client handles the SSO redirect flow with MCIAS.
|
||||
type Client struct {
|
||||
cfg Config
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// New creates an SSO client. TLS 1.3 is required for all HTTPS
|
||||
// connections to MCIAS.
|
||||
func New(cfg Config) (*Client, error) {
|
||||
if cfg.MciasURL == "" {
|
||||
return nil, fmt.Errorf("sso: mcias_url is required")
|
||||
}
|
||||
if cfg.ClientID == "" {
|
||||
return nil, fmt.Errorf("sso: client_id is required")
|
||||
}
|
||||
if cfg.RedirectURI == "" {
|
||||
return nil, fmt.Errorf("sso: redirect_uri is required")
|
||||
}
|
||||
|
||||
transport := &http.Transport{}
|
||||
|
||||
if !strings.HasPrefix(cfg.MciasURL, "http://") {
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
if cfg.CACert != "" {
|
||||
pem, err := os.ReadFile(cfg.CACert) //nolint:gosec // CA cert path from operator config
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sso: read CA cert %s: %w", cfg.CACert, err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(pem) {
|
||||
return nil, fmt.Errorf("sso: no valid certificates in %s", cfg.CACert)
|
||||
}
|
||||
tlsCfg.RootCAs = pool
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsCfg
|
||||
}
|
||||
|
||||
return &Client{
|
||||
cfg: cfg,
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AuthorizeURL returns the MCIAS authorize URL with the given state parameter.
|
||||
func (c *Client) AuthorizeURL(state string) string {
|
||||
base := strings.TrimRight(c.cfg.MciasURL, "/")
|
||||
return base + "/sso/authorize?" + url.Values{
|
||||
"client_id": {c.cfg.ClientID},
|
||||
"redirect_uri": {c.cfg.RedirectURI},
|
||||
"state": {state},
|
||||
}.Encode()
|
||||
}
|
||||
|
||||
// ExchangeCode exchanges an authorization code for a JWT token by calling
|
||||
// MCIAS POST /v1/sso/token.
|
||||
func (c *Client) ExchangeCode(ctx context.Context, code string) (token string, expiresAt time.Time, err error) {
|
||||
reqBody, _ := json.Marshal(map[string]string{
|
||||
"code": code,
|
||||
"client_id": c.cfg.ClientID,
|
||||
"redirect_uri": c.cfg.RedirectURI,
|
||||
})
|
||||
|
||||
base := strings.TrimRight(c.cfg.MciasURL, "/")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||
base+"/v1/sso/token", bytes.NewReader(reqBody))
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sso: build exchange request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sso: MCIAS exchange: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sso: read exchange response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", time.Time{}, fmt.Errorf("sso: exchange failed (HTTP %d): %s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sso: decode exchange response: %w", err)
|
||||
}
|
||||
|
||||
exp, parseErr := time.Parse(time.RFC3339, result.ExpiresAt)
|
||||
if parseErr != nil {
|
||||
exp = time.Now().Add(1 * time.Hour)
|
||||
}
|
||||
|
||||
return result.Token, exp, nil
|
||||
}
|
||||
|
||||
// GenerateState returns a cryptographically random hex-encoded state string.
|
||||
func GenerateState() (string, error) {
|
||||
raw := make([]byte, stateBytes)
|
||||
if _, err := rand.Read(raw); err != nil {
|
||||
return "", fmt.Errorf("sso: generate state: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(raw), nil
|
||||
}
|
||||
|
||||
// StateCookieName returns the cookie name used for SSO state for a given
|
||||
// service cookie prefix (e.g., "mcr" → "mcr_sso_state").
|
||||
func StateCookieName(prefix string) string {
|
||||
return prefix + "_sso_state"
|
||||
}
|
||||
|
||||
// ReturnToCookieName returns the cookie name used for SSO return-to URL
|
||||
// (e.g., "mcr" → "mcr_sso_return").
|
||||
func ReturnToCookieName(prefix string) string {
|
||||
return prefix + "_sso_return"
|
||||
}
|
||||
|
||||
// SetStateCookie stores the SSO state in a short-lived cookie.
|
||||
//
|
||||
// Security: SameSite=Lax is required because the redirect from MCIAS back to
|
||||
// the service is a cross-site top-level navigation. SameSite=Strict cookies
|
||||
// would not be sent on that redirect.
|
||||
func SetStateCookie(w http.ResponseWriter, prefix, state string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: StateCookieName(prefix),
|
||||
Value: state,
|
||||
Path: "/",
|
||||
MaxAge: stateCookieAge,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// ValidateStateCookie compares the state query parameter against the state
|
||||
// cookie. If they match, the cookie is cleared and nil is returned.
|
||||
func ValidateStateCookie(w http.ResponseWriter, r *http.Request, prefix, queryState string) error {
|
||||
c, err := r.Cookie(StateCookieName(prefix))
|
||||
if err != nil || c.Value == "" {
|
||||
return fmt.Errorf("sso: missing state cookie")
|
||||
}
|
||||
|
||||
if c.Value != queryState {
|
||||
return fmt.Errorf("sso: state mismatch")
|
||||
}
|
||||
|
||||
// Clear the state cookie (single-use).
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: StateCookieName(prefix),
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReturnToCookie stores the current request path so the service can
|
||||
// redirect back to it after SSO login completes.
|
||||
func SetReturnToCookie(w http.ResponseWriter, r *http.Request, prefix string) {
|
||||
path := r.URL.Path
|
||||
if path == "" || path == "/login" || path == "/sso/callback" {
|
||||
path = "/"
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: ReturnToCookieName(prefix),
|
||||
Value: path,
|
||||
Path: "/",
|
||||
MaxAge: stateCookieAge,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
}
|
||||
|
||||
// ConsumeReturnToCookie reads and clears the return-to cookie, returning
|
||||
// the path. Returns "/" if the cookie is missing or empty.
|
||||
func ConsumeReturnToCookie(w http.ResponseWriter, r *http.Request, prefix string) string {
|
||||
c, err := r.Cookie(ReturnToCookieName(prefix))
|
||||
path := "/"
|
||||
if err == nil && c.Value != "" {
|
||||
path = c.Value
|
||||
}
|
||||
|
||||
// Clear the cookie.
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: ReturnToCookieName(prefix),
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
})
|
||||
|
||||
return path
|
||||
}
|
||||
|
||||
// RedirectToLogin generates a state, sets the state and return-to cookies,
|
||||
// and redirects the user to the MCIAS authorize URL.
|
||||
func RedirectToLogin(w http.ResponseWriter, r *http.Request, client *Client, cookiePrefix string) error {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
SetStateCookie(w, cookiePrefix, state)
|
||||
SetReturnToCookie(w, r, cookiePrefix)
|
||||
http.Redirect(w, r, client.AuthorizeURL(state), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleCallback validates the state, exchanges the authorization code for
|
||||
// a JWT, and returns the token and the return-to path. The caller should
|
||||
// set the session cookie with the returned token.
|
||||
func HandleCallback(w http.ResponseWriter, r *http.Request, client *Client, cookiePrefix string) (token, returnTo string, err error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
state := r.URL.Query().Get("state")
|
||||
if code == "" || state == "" {
|
||||
return "", "", fmt.Errorf("sso: missing code or state parameter")
|
||||
}
|
||||
|
||||
if err := ValidateStateCookie(w, r, cookiePrefix, state); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
token, _, err = client.ExchangeCode(r.Context(), code)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
returnTo = ConsumeReturnToCookie(w, r, cookiePrefix)
|
||||
return token, returnTo, nil
|
||||
}
|
||||
225
sso/sso_test.go
Normal file
225
sso/sso_test.go
Normal file
@@ -0,0 +1,225 @@
|
||||
package sso
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg Config
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, false},
|
||||
{"missing url", Config{ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, true},
|
||||
{"missing client_id", Config{MciasURL: "https://mcias.example.com", RedirectURI: "https://mcr.example.com/cb"}, true},
|
||||
{"missing redirect_uri", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr"}, true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
_, err := New(tc.cfg)
|
||||
if tc.wantErr && err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthorizeURL(t *testing.T) {
|
||||
c, err := New(Config{
|
||||
MciasURL: "http://localhost:8443",
|
||||
ClientID: "mcr",
|
||||
RedirectURI: "https://mcr.example.com/sso/callback",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
u := c.AuthorizeURL("test-state")
|
||||
if u == "" {
|
||||
t.Fatal("AuthorizeURL returned empty string")
|
||||
}
|
||||
|
||||
// Should contain all required params.
|
||||
for _, want := range []string{"client_id=mcr", "state=test-state", "redirect_uri="} {
|
||||
if !contains(u, want) {
|
||||
t.Errorf("URL %q missing %q", u, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExchangeCode(t *testing.T) {
|
||||
// Fake MCIAS server.
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/sso/token" {
|
||||
http.Error(w, "not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Code string `json:"code"`
|
||||
ClientID string `json:"client_id"`
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if req.Code != "valid-code" {
|
||||
http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"token": "jwt-token-here",
|
||||
"expires_at": "2026-03-30T23:00:00Z",
|
||||
})
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c, err := New(Config{
|
||||
MciasURL: srv.URL,
|
||||
ClientID: "mcr",
|
||||
RedirectURI: "https://mcr.example.com/cb",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
// Valid code.
|
||||
token, _, err := c.ExchangeCode(t.Context(), "valid-code")
|
||||
if err != nil {
|
||||
t.Fatalf("ExchangeCode: %v", err)
|
||||
}
|
||||
if token != "jwt-token-here" {
|
||||
t.Errorf("token = %q, want %q", token, "jwt-token-here")
|
||||
}
|
||||
|
||||
// Invalid code.
|
||||
_, _, err = c.ExchangeCode(t.Context(), "bad-code")
|
||||
if err == nil {
|
||||
t.Error("expected error for bad code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState(t *testing.T) {
|
||||
s1, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState: %v", err)
|
||||
}
|
||||
if len(s1) != 64 { // 32 bytes = 64 hex chars
|
||||
t.Errorf("state length = %d, want 64", len(s1))
|
||||
}
|
||||
|
||||
s2, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState: %v", err)
|
||||
}
|
||||
if s1 == s2 {
|
||||
t.Error("two states should differ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateCookieRoundTrip(t *testing.T) {
|
||||
state := "test-state-value"
|
||||
rec := httptest.NewRecorder()
|
||||
SetStateCookie(rec, "mcr", state)
|
||||
|
||||
// Simulate a request with the cookie.
|
||||
req := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil)
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
if err := ValidateStateCookie(w, req, "mcr", state); err != nil {
|
||||
t.Fatalf("ValidateStateCookie: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStateCookieMismatch(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
SetStateCookie(rec, "mcr", "correct-state")
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/sso/callback?state=wrong-state", nil)
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
req.AddCookie(c)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
if err := ValidateStateCookie(w, req, "mcr", "wrong-state"); err == nil {
|
||||
t.Error("expected error for state mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnToCookie(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/repositories/myrepo", nil)
|
||||
SetReturnToCookie(rec, req, "mcr")
|
||||
|
||||
// Read back.
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
req2.AddCookie(c)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
path := ConsumeReturnToCookie(w2, req2, "mcr")
|
||||
if path != "/repositories/myrepo" {
|
||||
t.Errorf("return-to = %q, want %q", path, "/repositories/myrepo")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnToDefaultsToRoot(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
||||
w := httptest.NewRecorder()
|
||||
path := ConsumeReturnToCookie(w, req, "mcr")
|
||||
if path != "/" {
|
||||
t.Errorf("return-to = %q, want %q", path, "/")
|
||||
}
|
||||
}
|
||||
|
||||
func TestReturnToSkipsLoginPaths(t *testing.T) {
|
||||
for _, p := range []string{"/login", "/sso/callback"} {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, p, nil)
|
||||
SetReturnToCookie(rec, req, "mcr")
|
||||
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil)
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
req2.AddCookie(c)
|
||||
}
|
||||
|
||||
w2 := httptest.NewRecorder()
|
||||
path := ConsumeReturnToCookie(w2, req2, "mcr")
|
||||
if path != "/" {
|
||||
t.Errorf("return-to for %s = %q, want %q", p, path, "/")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s, sub string) bool {
|
||||
return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub))
|
||||
}
|
||||
|
||||
func containsStr(s, sub string) bool {
|
||||
for i := 0; i <= len(s)-len(sub); i++ {
|
||||
if s[i:i+len(sub)] == sub {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user