Add auth package: MCIAS token validation with caching
- Authenticator with Login, ValidateToken, Logout - 30-second SHA-256-keyed cache with lazy eviction - TLS 1.3, custom CA support, service context (name + tags) - Error types: ErrInvalidToken, ErrInvalidCredentials, ErrForbidden, ErrUnavailable - Context helpers for TokenInfo propagation - 14 tests with mock MCIAS server and injectable clock Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
29
PROGRESS.md
29
PROGRESS.md
@@ -2,7 +2,7 @@
|
||||
|
||||
## Current State
|
||||
|
||||
Phase 1 complete. The `db` package is implemented and tested.
|
||||
Phase 2 complete. The `db` and `auth` packages are implemented and tested.
|
||||
|
||||
## Completed
|
||||
|
||||
@@ -24,11 +24,28 @@ Phase 1 complete. The `db` package is implemented and tested.
|
||||
- `SchemaVersion(database *sql.DB) (int, error)` — highest applied version
|
||||
- `Snapshot(database *sql.DB, destPath string) error` — VACUUM INTO with
|
||||
0600 permissions, creates parent dirs
|
||||
- 11 tests: open (pragmas, permissions, parent dir, existing DB), migrate
|
||||
(fresh, idempotent, incremental, records name), schema version (empty),
|
||||
snapshot (data integrity, permissions, parent dir)
|
||||
- `make all` passes clean (vet, lint 0 issues, 11/11 tests, build)
|
||||
- 11 tests covering open, migrate, and snapshot
|
||||
|
||||
### Phase 2: `auth` — MCIAS Token Validation (2026-03-25)
|
||||
- `Config` type matching `[mcias]` TOML section (ServerURL, CACert,
|
||||
ServiceName, Tags)
|
||||
- `TokenInfo` type (Username, Roles, IsAdmin)
|
||||
- `New(cfg Config, logger *slog.Logger) (*Authenticator, error)` — creates
|
||||
MCIAS client with TLS 1.3, custom CA support, 10s timeout
|
||||
- `Login(username, password, totpCode string) (token, expiresAt, err)` —
|
||||
forwards to MCIAS with service context, returns ErrForbidden for policy
|
||||
denials, ErrInvalidCredentials for bad creds
|
||||
- `ValidateToken(token string) (*TokenInfo, error)` — 30s SHA-256-keyed
|
||||
cache, lazy eviction, concurrent-safe (RWMutex)
|
||||
- `Logout(token string) error` — revokes token on MCIAS
|
||||
- Error types: ErrInvalidToken, ErrInvalidCredentials, ErrForbidden,
|
||||
ErrUnavailable
|
||||
- Context helpers: ContextWithTokenInfo, TokenInfoFromContext
|
||||
- 14 tests: login (success, invalid creds, forbidden), validate (admin,
|
||||
non-admin, expired, unknown), cache (hit, expiry via injectable clock),
|
||||
logout, constructor validation, context roundtrip, admin detection
|
||||
- `make all` passes clean (vet, lint 0 issues, 25 total tests, build)
|
||||
|
||||
## Next Steps
|
||||
|
||||
- Phase 2: `auth` package (MCIAS token validation with caching)
|
||||
- Phase 3: `config` package (TOML loading, env overrides, standard sections)
|
||||
|
||||
289
auth/auth.go
Normal file
289
auth/auth.go
Normal file
@@ -0,0 +1,289 @@
|
||||
// Package auth provides MCIAS token validation with caching for
|
||||
// Metacircular services.
|
||||
//
|
||||
// Every Metacircular service delegates authentication to MCIAS. This
|
||||
// package handles the login flow, token validation (with a 30-second
|
||||
// SHA-256-keyed cache), and logout. It communicates directly with the
|
||||
// MCIAS REST API.
|
||||
//
|
||||
// Security: bearer tokens are never logged or included in error messages.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const cacheTTL = 30 * time.Second
|
||||
|
||||
// Errors returned by the Authenticator.
|
||||
var (
|
||||
// ErrInvalidToken indicates the token is expired, revoked, or otherwise
|
||||
// invalid.
|
||||
ErrInvalidToken = errors.New("auth: invalid token")
|
||||
|
||||
// ErrInvalidCredentials indicates that the username/password combination
|
||||
// was rejected by MCIAS.
|
||||
ErrInvalidCredentials = errors.New("auth: invalid credentials")
|
||||
|
||||
// ErrForbidden indicates that MCIAS login policy denied access to this
|
||||
// service (HTTP 403).
|
||||
ErrForbidden = errors.New("auth: forbidden by policy")
|
||||
|
||||
// ErrUnavailable indicates that MCIAS could not be reached.
|
||||
ErrUnavailable = errors.New("auth: MCIAS unavailable")
|
||||
)
|
||||
|
||||
// Config holds MCIAS connection settings. This matches the standard [mcias]
|
||||
// TOML section used by all Metacircular services.
|
||||
type Config struct {
|
||||
// ServerURL is the base URL of the MCIAS server
|
||||
// (e.g., "https://mcias.metacircular.net:8443").
|
||||
ServerURL string `toml:"server_url"`
|
||||
|
||||
// CACert is an optional path to a PEM-encoded CA certificate for
|
||||
// verifying the MCIAS server's TLS certificate.
|
||||
CACert string `toml:"ca_cert"`
|
||||
|
||||
// ServiceName is this service's identity as registered in MCIAS. It is
|
||||
// sent with every login request so MCIAS can evaluate service-context
|
||||
// login policy rules.
|
||||
ServiceName string `toml:"service_name"`
|
||||
|
||||
// Tags are sent with every login request. MCIAS evaluates auth:login
|
||||
// policy against these tags (e.g., ["env:restricted"]).
|
||||
Tags []string `toml:"tags"`
|
||||
}
|
||||
|
||||
// TokenInfo holds the validated identity of an authenticated caller.
|
||||
type TokenInfo struct {
|
||||
// Username is the MCIAS username (the "sub" claim).
|
||||
Username string
|
||||
|
||||
// Roles is the set of MCIAS roles assigned to the account.
|
||||
Roles []string
|
||||
|
||||
// IsAdmin is true if the account has the "admin" role.
|
||||
IsAdmin bool
|
||||
}
|
||||
|
||||
// Authenticator validates MCIAS bearer tokens with a short-lived cache.
|
||||
type Authenticator struct {
|
||||
httpClient *http.Client
|
||||
baseURL string
|
||||
serviceName string
|
||||
tags []string
|
||||
logger *slog.Logger
|
||||
cache *validationCache
|
||||
}
|
||||
|
||||
// New creates an Authenticator that talks to the MCIAS server described
|
||||
// by cfg. TLS 1.3 is required for all HTTPS connections. If cfg.CACert
|
||||
// is set, that CA certificate is added to the trust pool.
|
||||
//
|
||||
// For plain HTTP URLs (used in tests), TLS configuration is skipped.
|
||||
func New(cfg Config, logger *slog.Logger) (*Authenticator, error) {
|
||||
if cfg.ServerURL == "" {
|
||||
return nil, fmt.Errorf("auth: server_url is required")
|
||||
}
|
||||
|
||||
transport := &http.Transport{}
|
||||
|
||||
if !strings.HasPrefix(cfg.ServerURL, "http://") {
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
if cfg.CACert != "" {
|
||||
pem, err := os.ReadFile(cfg.CACert) //nolint:gosec // CA cert path from operator config
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth: read CA cert %s: %w", cfg.CACert, err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(pem) {
|
||||
return nil, fmt.Errorf("auth: no valid certificates in %s", cfg.CACert)
|
||||
}
|
||||
tlsCfg.RootCAs = pool
|
||||
}
|
||||
|
||||
transport.TLSClientConfig = tlsCfg
|
||||
}
|
||||
|
||||
return &Authenticator{
|
||||
httpClient: &http.Client{
|
||||
Transport: transport,
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
baseURL: strings.TrimRight(cfg.ServerURL, "/"),
|
||||
serviceName: cfg.ServiceName,
|
||||
tags: cfg.Tags,
|
||||
logger: logger,
|
||||
cache: newCache(cacheTTL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Login authenticates a user against MCIAS and returns a bearer token.
|
||||
// totpCode may be empty for accounts without TOTP configured.
|
||||
//
|
||||
// The service name and tags from Config are included in the login request
|
||||
// so MCIAS can evaluate service-context login policy.
|
||||
func (a *Authenticator) Login(username, password, totpCode string) (token string, expiresAt time.Time, err error) {
|
||||
reqBody := map[string]interface{}{
|
||||
"username": username,
|
||||
"password": password,
|
||||
}
|
||||
if totpCode != "" {
|
||||
reqBody["totp_code"] = totpCode
|
||||
}
|
||||
if a.serviceName != "" {
|
||||
reqBody["service_name"] = a.serviceName
|
||||
}
|
||||
if len(a.tags) > 0 {
|
||||
reqBody["tags"] = a.tags
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
status, err := a.doJSON(http.MethodPost, "/v1/auth/login", reqBody, &resp)
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("auth: MCIAS login: %w", ErrUnavailable)
|
||||
}
|
||||
|
||||
switch status {
|
||||
case http.StatusOK:
|
||||
// Parse the expiry time.
|
||||
exp, parseErr := time.Parse(time.RFC3339, resp.ExpiresAt)
|
||||
if parseErr != nil {
|
||||
exp = time.Now().Add(1 * time.Hour) // fallback
|
||||
}
|
||||
return resp.Token, exp, nil
|
||||
case http.StatusForbidden:
|
||||
return "", time.Time{}, ErrForbidden
|
||||
default:
|
||||
return "", time.Time{}, ErrInvalidCredentials
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateToken checks a bearer token against MCIAS. Results are cached
|
||||
// by the SHA-256 hash of the token for 30 seconds.
|
||||
//
|
||||
// Returns ErrInvalidToken if the token is expired, revoked, or otherwise
|
||||
// not valid.
|
||||
func (a *Authenticator) ValidateToken(token string) (*TokenInfo, error) {
|
||||
h := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(h[:])
|
||||
|
||||
if info, ok := a.cache.get(tokenHash); ok {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Valid bool `json:"valid"`
|
||||
Sub string `json:"sub"`
|
||||
Username string `json:"username"`
|
||||
Roles []string `json:"roles"`
|
||||
}
|
||||
status, err := a.doJSON(http.MethodPost, "/v1/token/validate",
|
||||
map[string]string{"token": token}, &resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("auth: MCIAS validate: %w", ErrUnavailable)
|
||||
}
|
||||
|
||||
if status != http.StatusOK || !resp.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
info := &TokenInfo{
|
||||
Username: resp.Username,
|
||||
Roles: resp.Roles,
|
||||
IsAdmin: hasRole(resp.Roles, "admin"),
|
||||
}
|
||||
if info.Username == "" {
|
||||
info.Username = resp.Sub
|
||||
}
|
||||
|
||||
a.cache.put(tokenHash, info)
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// Logout revokes a token on the MCIAS server.
|
||||
func (a *Authenticator) Logout(token string) error {
|
||||
req, err := http.NewRequestWithContext(context.Background(),
|
||||
http.MethodPost, a.baseURL+"/v1/auth/logout", nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("auth: build logout request: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("auth: MCIAS logout: %w", ErrUnavailable)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// doJSON makes a JSON request to the MCIAS server and decodes the response.
|
||||
// It returns the HTTP status code and any transport error.
|
||||
func (a *Authenticator) doJSON(method, path string, body, out interface{}) (int, error) {
|
||||
var reqBody io.Reader
|
||||
if body != nil {
|
||||
b, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
reqBody = bytes.NewReader(b)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(context.Background(),
|
||||
method, a.baseURL+path, reqBody)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("build request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if out != nil && resp.StatusCode == http.StatusOK {
|
||||
respBytes, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
return resp.StatusCode, fmt.Errorf("read response: %w", readErr)
|
||||
}
|
||||
if len(respBytes) > 0 {
|
||||
if decErr := json.Unmarshal(respBytes, out); decErr != nil {
|
||||
return resp.StatusCode, fmt.Errorf("decode response: %w", decErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resp.StatusCode, nil
|
||||
}
|
||||
|
||||
func hasRole(roles []string, target string) bool {
|
||||
for _, r := range roles {
|
||||
if r == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
346
auth/auth_test.go
Normal file
346
auth/auth_test.go
Normal file
@@ -0,0 +1,346 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockMCIAS returns a test HTTP server that simulates MCIAS endpoints.
|
||||
func mockMCIAS(t *testing.T) *httptest.Server {
|
||||
t.Helper()
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.HandleFunc("POST /v1/auth/login", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
ServiceName string `json:"service_name"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"bad request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Username == "admin" && req.Password == "secret" {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
exp := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"token": "tok-admin-123",
|
||||
"expires_at": exp,
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.Username == "denied" && req.Password == "secret" {
|
||||
http.Error(w, `{"error":"forbidden by policy"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
http.Error(w, `{"error":"invalid credentials"}`, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, `{"error":"bad request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
switch req.Token {
|
||||
case "tok-admin-123":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"valid": true,
|
||||
"sub": "uuid-admin",
|
||||
"username": "admin",
|
||||
"roles": []string{"admin", "user"},
|
||||
})
|
||||
case "tok-user-456":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"valid": true,
|
||||
"sub": "uuid-user",
|
||||
"username": "alice",
|
||||
"roles": []string{"user"},
|
||||
})
|
||||
case "tok-expired":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"valid": false,
|
||||
})
|
||||
default:
|
||||
http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized)
|
||||
}
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /v1/auth/logout", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
})
|
||||
|
||||
return httptest.NewServer(mux)
|
||||
}
|
||||
|
||||
func newTestAuth(t *testing.T, serverURL string) *Authenticator {
|
||||
t.Helper()
|
||||
a, err := New(Config{
|
||||
ServerURL: serverURL,
|
||||
ServiceName: "test-service",
|
||||
Tags: []string{"env:test"},
|
||||
}, slog.Default())
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
return a
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
token, exp, err := a.Login("admin", "secret", "")
|
||||
if err != nil {
|
||||
t.Fatalf("Login: %v", err)
|
||||
}
|
||||
if token != "tok-admin-123" {
|
||||
t.Fatalf("token = %q, want %q", token, "tok-admin-123")
|
||||
}
|
||||
if exp.IsZero() {
|
||||
t.Fatal("expiresAt is zero")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginInvalidCredentials(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
_, _, err := a.Login("admin", "wrong", "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid credentials")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Fatalf("err = %v, want ErrInvalidCredentials", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginForbidden(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
_, _, err := a.Login("denied", "secret", "")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for forbidden login")
|
||||
}
|
||||
if !errors.Is(err, ErrForbidden) {
|
||||
t.Fatalf("err = %v, want ErrForbidden", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateToken(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
info, err := a.ValidateToken("tok-admin-123")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken: %v", err)
|
||||
}
|
||||
if info.Username != "admin" {
|
||||
t.Fatalf("Username = %q, want %q", info.Username, "admin")
|
||||
}
|
||||
if !info.IsAdmin {
|
||||
t.Fatal("IsAdmin = false, want true")
|
||||
}
|
||||
if len(info.Roles) != 2 {
|
||||
t.Fatalf("Roles = %v, want 2 roles", info.Roles)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenNonAdmin(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
info, err := a.ValidateToken("tok-user-456")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken: %v", err)
|
||||
}
|
||||
if info.Username != "alice" {
|
||||
t.Fatalf("Username = %q, want %q", info.Username, "alice")
|
||||
}
|
||||
if info.IsAdmin {
|
||||
t.Fatal("IsAdmin = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenExpired(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
_, err := a.ValidateToken("tok-expired")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for expired token")
|
||||
}
|
||||
if !errors.Is(err, ErrInvalidToken) {
|
||||
t.Fatalf("err = %v, want ErrInvalidToken", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenUnknown(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
_, err := a.ValidateToken("tok-unknown")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenCache(t *testing.T) {
|
||||
callCount := 0
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, _ *http.Request) {
|
||||
callCount++
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"valid": true,
|
||||
"username": "cached-user",
|
||||
"roles": []string{"user"},
|
||||
})
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
// First call: cache miss, hits server.
|
||||
info1, err := a.ValidateToken("tok-cache-test")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken (1st): %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("server calls = %d, want 1", callCount)
|
||||
}
|
||||
|
||||
// Second call: cache hit, no server call.
|
||||
info2, err := a.ValidateToken("tok-cache-test")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken (2nd): %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("server calls = %d, want 1 (cached)", callCount)
|
||||
}
|
||||
if info1.Username != info2.Username {
|
||||
t.Fatalf("cached username mismatch: %q vs %q", info1.Username, info2.Username)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTokenCacheExpiry(t *testing.T) {
|
||||
callCount := 0
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, _ *http.Request) {
|
||||
callCount++
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"valid": true,
|
||||
"username": "user",
|
||||
"roles": []string{"user"},
|
||||
})
|
||||
})
|
||||
srv := httptest.NewServer(mux)
|
||||
defer srv.Close()
|
||||
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
// Override the cache clock to simulate time passing.
|
||||
now := time.Now()
|
||||
a.cache.now = func() time.Time { return now }
|
||||
|
||||
_, err := a.ValidateToken("tok-expiry-test")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken (1st): %v", err)
|
||||
}
|
||||
if callCount != 1 {
|
||||
t.Fatalf("server calls = %d, want 1", callCount)
|
||||
}
|
||||
|
||||
// Advance past cache TTL.
|
||||
now = now.Add(cacheTTL + 1*time.Second)
|
||||
|
||||
_, err = a.ValidateToken("tok-expiry-test")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateToken (2nd): %v", err)
|
||||
}
|
||||
if callCount != 2 {
|
||||
t.Fatalf("server calls = %d, want 2 (cache expired)", callCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
srv := mockMCIAS(t)
|
||||
defer srv.Close()
|
||||
a := newTestAuth(t, srv.URL)
|
||||
|
||||
if err := a.Logout("tok-admin-123"); err != nil {
|
||||
t.Fatalf("Logout: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRequiresServerURL(t *testing.T) {
|
||||
_, err := New(Config{}, slog.Default())
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty server_url")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextRoundtrip(t *testing.T) {
|
||||
info := &TokenInfo{
|
||||
Username: "testuser",
|
||||
Roles: []string{"user"},
|
||||
IsAdmin: false,
|
||||
}
|
||||
|
||||
ctx := ContextWithTokenInfo(context.Background(), info)
|
||||
got := TokenInfoFromContext(ctx)
|
||||
|
||||
if got == nil {
|
||||
t.Fatal("TokenInfoFromContext returned nil")
|
||||
}
|
||||
if got.Username != "testuser" {
|
||||
t.Fatalf("Username = %q, want %q", got.Username, "testuser")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextMissing(t *testing.T) {
|
||||
got := TokenInfoFromContext(context.Background())
|
||||
if got != nil {
|
||||
t.Fatalf("expected nil from empty context, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
roles []string
|
||||
want bool
|
||||
}{
|
||||
{[]string{"admin", "user"}, true},
|
||||
{[]string{"admin"}, true},
|
||||
{[]string{"user"}, false},
|
||||
{[]string{}, false},
|
||||
{nil, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := hasRole(tt.roles, "admin")
|
||||
if got != tt.want {
|
||||
t.Errorf("hasRole(%v, admin) = %v, want %v", tt.roles, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
64
auth/cache.go
Normal file
64
auth/cache.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// cacheEntry holds a cached TokenInfo and its expiration time.
|
||||
type cacheEntry struct {
|
||||
info *TokenInfo
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// validationCache provides a concurrency-safe, TTL-based cache for token
|
||||
// validation results. Tokens are keyed by their SHA-256 hex digest.
|
||||
type validationCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]cacheEntry
|
||||
ttl time.Duration
|
||||
now func() time.Time // injectable clock for testing
|
||||
}
|
||||
|
||||
// newCache creates a validationCache with the given TTL.
|
||||
func newCache(ttl time.Duration) *validationCache {
|
||||
return &validationCache{
|
||||
entries: make(map[string]cacheEntry),
|
||||
ttl: ttl,
|
||||
now: time.Now,
|
||||
}
|
||||
}
|
||||
|
||||
// get returns cached TokenInfo for the given token hash, or false if
|
||||
// the entry is missing or expired. Expired entries are lazily evicted.
|
||||
func (c *validationCache) get(tokenHash string) (*TokenInfo, bool) {
|
||||
c.mu.RLock()
|
||||
entry, ok := c.entries[tokenHash]
|
||||
c.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if c.now().After(entry.expiresAt) {
|
||||
// Lazy evict the expired entry.
|
||||
c.mu.Lock()
|
||||
if e, exists := c.entries[tokenHash]; exists && c.now().After(e.expiresAt) {
|
||||
delete(c.entries, tokenHash)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return entry.info, true
|
||||
}
|
||||
|
||||
// put stores TokenInfo in the cache with an expiration of now + TTL.
|
||||
func (c *validationCache) put(tokenHash string, info *TokenInfo) {
|
||||
c.mu.Lock()
|
||||
c.entries[tokenHash] = cacheEntry{
|
||||
info: info,
|
||||
expiresAt: c.now().Add(c.ttl),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
19
auth/context.go
Normal file
19
auth/context.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
// contextKey is an unexported type used as the context key for TokenInfo,
|
||||
// preventing collisions with keys from other packages.
|
||||
type contextKey struct{}
|
||||
|
||||
// ContextWithTokenInfo returns a new context carrying the given TokenInfo.
|
||||
func ContextWithTokenInfo(ctx context.Context, info *TokenInfo) context.Context {
|
||||
return context.WithValue(ctx, contextKey{}, info)
|
||||
}
|
||||
|
||||
// TokenInfoFromContext extracts TokenInfo from the context. It returns nil
|
||||
// if no TokenInfo is present.
|
||||
func TokenInfoFromContext(ctx context.Context) *TokenInfo {
|
||||
info, _ := ctx.Value(contextKey{}).(*TokenInfo)
|
||||
return info
|
||||
}
|
||||
Reference in New Issue
Block a user