diff --git a/sso/sso.go b/sso/sso.go new file mode 100644 index 0000000..d2e17dd --- /dev/null +++ b/sso/sso.go @@ -0,0 +1,304 @@ +// Package sso provides an SSO redirect client for Metacircular web services. +// +// Services redirect unauthenticated users to MCIAS for login. After +// authentication, MCIAS redirects back with an authorization code that +// the service exchanges for a JWT token. This package handles the +// redirect, state management, and code exchange. +// +// Security design: +// - State cookies use SameSite=Lax (not Strict) because the redirect from +// MCIAS back to the service is a cross-site navigation. +// - State is a 256-bit random value stored in an HttpOnly cookie. +// - Return-to URLs are stored in a separate cookie so MCIAS never sees them. +// - The code exchange is a server-to-server HTTPS call (TLS 1.3 minimum). +package sso + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" +) + +const ( + stateBytes = 32 // 256 bits + stateCookieAge = 5 * 60 // 5 minutes in seconds +) + +// Config holds the SSO client configuration. The values must match the +// SSO client registration in MCIAS config. +type Config struct { + // MciasURL is the base URL of the MCIAS server. + MciasURL string + + // ClientID is the registered SSO client identifier. + ClientID string + + // RedirectURI is the callback URL that MCIAS redirects to after login. + // Must exactly match the redirect_uri registered in MCIAS config. + RedirectURI string + + // CACert is an optional path to a PEM-encoded CA certificate for + // verifying the MCIAS server's TLS certificate. + CACert string +} + +// Client handles the SSO redirect flow with MCIAS. +type Client struct { + cfg Config + httpClient *http.Client +} + +// New creates an SSO client. TLS 1.3 is required for all HTTPS +// connections to MCIAS. +func New(cfg Config) (*Client, error) { + if cfg.MciasURL == "" { + return nil, fmt.Errorf("sso: mcias_url is required") + } + if cfg.ClientID == "" { + return nil, fmt.Errorf("sso: client_id is required") + } + if cfg.RedirectURI == "" { + return nil, fmt.Errorf("sso: redirect_uri is required") + } + + transport := &http.Transport{} + + if !strings.HasPrefix(cfg.MciasURL, "http://") { + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + } + + if cfg.CACert != "" { + pem, err := os.ReadFile(cfg.CACert) //nolint:gosec // CA cert path from operator config + if err != nil { + return nil, fmt.Errorf("sso: read CA cert %s: %w", cfg.CACert, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pem) { + return nil, fmt.Errorf("sso: no valid certificates in %s", cfg.CACert) + } + tlsCfg.RootCAs = pool + } + + transport.TLSClientConfig = tlsCfg + } + + return &Client{ + cfg: cfg, + httpClient: &http.Client{ + Transport: transport, + Timeout: 10 * time.Second, + }, + }, nil +} + +// AuthorizeURL returns the MCIAS authorize URL with the given state parameter. +func (c *Client) AuthorizeURL(state string) string { + base := strings.TrimRight(c.cfg.MciasURL, "/") + return base + "/sso/authorize?" + url.Values{ + "client_id": {c.cfg.ClientID}, + "redirect_uri": {c.cfg.RedirectURI}, + "state": {state}, + }.Encode() +} + +// ExchangeCode exchanges an authorization code for a JWT token by calling +// MCIAS POST /v1/sso/token. +func (c *Client) ExchangeCode(ctx context.Context, code string) (token string, expiresAt time.Time, err error) { + reqBody, _ := json.Marshal(map[string]string{ + "code": code, + "client_id": c.cfg.ClientID, + "redirect_uri": c.cfg.RedirectURI, + }) + + base := strings.TrimRight(c.cfg.MciasURL, "/") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, + base+"/v1/sso/token", bytes.NewReader(reqBody)) + if err != nil { + return "", time.Time{}, fmt.Errorf("sso: build exchange request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", time.Time{}, fmt.Errorf("sso: MCIAS exchange: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", time.Time{}, fmt.Errorf("sso: read exchange response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", time.Time{}, fmt.Errorf("sso: exchange failed (HTTP %d): %s", resp.StatusCode, body) + } + + var result struct { + Token string `json:"token"` + ExpiresAt string `json:"expires_at"` + } + if err := json.Unmarshal(body, &result); err != nil { + return "", time.Time{}, fmt.Errorf("sso: decode exchange response: %w", err) + } + + exp, parseErr := time.Parse(time.RFC3339, result.ExpiresAt) + if parseErr != nil { + exp = time.Now().Add(1 * time.Hour) + } + + return result.Token, exp, nil +} + +// GenerateState returns a cryptographically random hex-encoded state string. +func GenerateState() (string, error) { + raw := make([]byte, stateBytes) + if _, err := rand.Read(raw); err != nil { + return "", fmt.Errorf("sso: generate state: %w", err) + } + return hex.EncodeToString(raw), nil +} + +// StateCookieName returns the cookie name used for SSO state for a given +// service cookie prefix (e.g., "mcr" → "mcr_sso_state"). +func StateCookieName(prefix string) string { + return prefix + "_sso_state" +} + +// ReturnToCookieName returns the cookie name used for SSO return-to URL +// (e.g., "mcr" → "mcr_sso_return"). +func ReturnToCookieName(prefix string) string { + return prefix + "_sso_return" +} + +// SetStateCookie stores the SSO state in a short-lived cookie. +// +// Security: SameSite=Lax is required because the redirect from MCIAS back to +// the service is a cross-site top-level navigation. SameSite=Strict cookies +// would not be sent on that redirect. +func SetStateCookie(w http.ResponseWriter, prefix, state string) { + http.SetCookie(w, &http.Cookie{ + Name: StateCookieName(prefix), + Value: state, + Path: "/", + MaxAge: stateCookieAge, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) +} + +// ValidateStateCookie compares the state query parameter against the state +// cookie. If they match, the cookie is cleared and nil is returned. +func ValidateStateCookie(w http.ResponseWriter, r *http.Request, prefix, queryState string) error { + c, err := r.Cookie(StateCookieName(prefix)) + if err != nil || c.Value == "" { + return fmt.Errorf("sso: missing state cookie") + } + + if c.Value != queryState { + return fmt.Errorf("sso: state mismatch") + } + + // Clear the state cookie (single-use). + http.SetCookie(w, &http.Cookie{ + Name: StateCookieName(prefix), + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + return nil +} + +// SetReturnToCookie stores the current request path so the service can +// redirect back to it after SSO login completes. +func SetReturnToCookie(w http.ResponseWriter, r *http.Request, prefix string) { + path := r.URL.Path + if path == "" || path == "/login" || path == "/sso/callback" { + path = "/" + } + http.SetCookie(w, &http.Cookie{ + Name: ReturnToCookieName(prefix), + Value: path, + Path: "/", + MaxAge: stateCookieAge, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) +} + +// ConsumeReturnToCookie reads and clears the return-to cookie, returning +// the path. Returns "/" if the cookie is missing or empty. +func ConsumeReturnToCookie(w http.ResponseWriter, r *http.Request, prefix string) string { + c, err := r.Cookie(ReturnToCookieName(prefix)) + path := "/" + if err == nil && c.Value != "" { + path = c.Value + } + + // Clear the cookie. + http.SetCookie(w, &http.Cookie{ + Name: ReturnToCookieName(prefix), + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + }) + + return path +} + +// RedirectToLogin generates a state, sets the state and return-to cookies, +// and redirects the user to the MCIAS authorize URL. +func RedirectToLogin(w http.ResponseWriter, r *http.Request, client *Client, cookiePrefix string) error { + state, err := GenerateState() + if err != nil { + return err + } + + SetStateCookie(w, cookiePrefix, state) + SetReturnToCookie(w, r, cookiePrefix) + http.Redirect(w, r, client.AuthorizeURL(state), http.StatusFound) + return nil +} + +// HandleCallback validates the state, exchanges the authorization code for +// a JWT, and returns the token and the return-to path. The caller should +// set the session cookie with the returned token. +func HandleCallback(w http.ResponseWriter, r *http.Request, client *Client, cookiePrefix string) (token, returnTo string, err error) { + code := r.URL.Query().Get("code") + state := r.URL.Query().Get("state") + if code == "" || state == "" { + return "", "", fmt.Errorf("sso: missing code or state parameter") + } + + if err := ValidateStateCookie(w, r, cookiePrefix, state); err != nil { + return "", "", err + } + + token, _, err = client.ExchangeCode(r.Context(), code) + if err != nil { + return "", "", err + } + + returnTo = ConsumeReturnToCookie(w, r, cookiePrefix) + return token, returnTo, nil +} diff --git a/sso/sso_test.go b/sso/sso_test.go new file mode 100644 index 0000000..568156b --- /dev/null +++ b/sso/sso_test.go @@ -0,0 +1,225 @@ +package sso + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewValidation(t *testing.T) { + tests := []struct { + name string + cfg Config + wantErr bool + }{ + {"valid", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, false}, + {"missing url", Config{ClientID: "mcr", RedirectURI: "https://mcr.example.com/cb"}, true}, + {"missing client_id", Config{MciasURL: "https://mcias.example.com", RedirectURI: "https://mcr.example.com/cb"}, true}, + {"missing redirect_uri", Config{MciasURL: "https://mcias.example.com", ClientID: "mcr"}, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, err := New(tc.cfg) + if tc.wantErr && err == nil { + t.Error("expected error, got nil") + } + if !tc.wantErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestAuthorizeURL(t *testing.T) { + c, err := New(Config{ + MciasURL: "http://localhost:8443", + ClientID: "mcr", + RedirectURI: "https://mcr.example.com/sso/callback", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + u := c.AuthorizeURL("test-state") + if u == "" { + t.Fatal("AuthorizeURL returned empty string") + } + + // Should contain all required params. + for _, want := range []string{"client_id=mcr", "state=test-state", "redirect_uri="} { + if !contains(u, want) { + t.Errorf("URL %q missing %q", u, want) + } + } +} + +func TestExchangeCode(t *testing.T) { + // Fake MCIAS server. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/sso/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + var req struct { + Code string `json:"code"` + ClientID string `json:"client_id"` + RedirectURI string `json:"redirect_uri"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + if req.Code != "valid-code" { + http.Error(w, `{"error":"invalid code"}`, http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{ + "token": "jwt-token-here", + "expires_at": "2026-03-30T23:00:00Z", + }) + })) + defer srv.Close() + + c, err := New(Config{ + MciasURL: srv.URL, + ClientID: "mcr", + RedirectURI: "https://mcr.example.com/cb", + }) + if err != nil { + t.Fatalf("New: %v", err) + } + + // Valid code. + token, _, err := c.ExchangeCode(t.Context(), "valid-code") + if err != nil { + t.Fatalf("ExchangeCode: %v", err) + } + if token != "jwt-token-here" { + t.Errorf("token = %q, want %q", token, "jwt-token-here") + } + + // Invalid code. + _, _, err = c.ExchangeCode(t.Context(), "bad-code") + if err == nil { + t.Error("expected error for bad code") + } +} + +func TestGenerateState(t *testing.T) { + s1, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState: %v", err) + } + if len(s1) != 64 { // 32 bytes = 64 hex chars + t.Errorf("state length = %d, want 64", len(s1)) + } + + s2, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState: %v", err) + } + if s1 == s2 { + t.Error("two states should differ") + } +} + +func TestStateCookieRoundTrip(t *testing.T) { + state := "test-state-value" + rec := httptest.NewRecorder() + SetStateCookie(rec, "mcr", state) + + // Simulate a request with the cookie. + req := httptest.NewRequest(http.MethodGet, "/sso/callback?state="+state, nil) + for _, c := range rec.Result().Cookies() { + req.AddCookie(c) + } + + w := httptest.NewRecorder() + if err := ValidateStateCookie(w, req, "mcr", state); err != nil { + t.Fatalf("ValidateStateCookie: %v", err) + } +} + +func TestStateCookieMismatch(t *testing.T) { + rec := httptest.NewRecorder() + SetStateCookie(rec, "mcr", "correct-state") + + req := httptest.NewRequest(http.MethodGet, "/sso/callback?state=wrong-state", nil) + for _, c := range rec.Result().Cookies() { + req.AddCookie(c) + } + + w := httptest.NewRecorder() + if err := ValidateStateCookie(w, req, "mcr", "wrong-state"); err == nil { + t.Error("expected error for state mismatch") + } +} + +func TestReturnToCookie(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/repositories/myrepo", nil) + SetReturnToCookie(rec, req, "mcr") + + // Read back. + req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil) + for _, c := range rec.Result().Cookies() { + req2.AddCookie(c) + } + + w2 := httptest.NewRecorder() + path := ConsumeReturnToCookie(w2, req2, "mcr") + if path != "/repositories/myrepo" { + t.Errorf("return-to = %q, want %q", path, "/repositories/myrepo") + } +} + +func TestReturnToDefaultsToRoot(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/sso/callback", nil) + w := httptest.NewRecorder() + path := ConsumeReturnToCookie(w, req, "mcr") + if path != "/" { + t.Errorf("return-to = %q, want %q", path, "/") + } +} + +func TestReturnToSkipsLoginPaths(t *testing.T) { + for _, p := range []string{"/login", "/sso/callback"} { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, p, nil) + SetReturnToCookie(rec, req, "mcr") + + req2 := httptest.NewRequest(http.MethodGet, "/sso/callback", nil) + for _, c := range rec.Result().Cookies() { + req2.AddCookie(c) + } + + w2 := httptest.NewRecorder() + path := ConsumeReturnToCookie(w2, req2, "mcr") + if path != "/" { + t.Errorf("return-to for %s = %q, want %q", p, path, "/") + } + } +} + +func contains(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(s) > 0 && containsStr(s, sub)) +} + +func containsStr(s, sub string) bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false +}