diff --git a/.claude/skills/checkpoint/SKILL.md b/.claude/skills/checkpoint/SKILL.md index cae0149..437dfeb 100644 --- a/.claude/skills/checkpoint/SKILL.md +++ b/.claude/skills/checkpoint/SKILL.md @@ -4,5 +4,5 @@ 2. Run `go test ./...` abort if failures 3. Run `go vet ./...` 4. Run `git add -A && git status` show user what will be committed -5. Ask user for commit message +5. Generate an appropriate commit message based on your instructions. 6. Run `git commit -m ""` and verify with `git log -1` \ No newline at end of file diff --git a/.gitignore b/.gitignore index c27aa0f..5f470ab 100644 --- a/.gitignore +++ b/.gitignore @@ -34,5 +34,10 @@ clients/python/*.egg-info/ clients/lisp/**/*.fasl # manual testing -/run/ +run/ .env +/cmd/mciasctl/mciasctl +/cmd/mciasdb/mciasdb +/cmd/mciasgrpcctl/mciasgrpcctl +/cmd/mciassrv/mciassrv + diff --git a/clients/go/client.go b/clients/go/client.go index 272bb1d..490d867 100644 --- a/clients/go/client.go +++ b/clients/go/client.go @@ -3,6 +3,7 @@ // Security: bearer tokens are stored under a sync.RWMutex and are never written // to logs or included in error messages anywhere in this package. package mciasgoclient + import ( "bytes" "crypto/tls" @@ -15,32 +16,43 @@ import ( "strings" "sync" ) + // --------------------------------------------------------------------------- // Error types // --------------------------------------------------------------------------- + // MciasError is the base error type for all MCIAS client errors. type MciasError struct { StatusCode int Message string } + func (e *MciasError) Error() string { return fmt.Sprintf("mciasgoclient: HTTP %d: %s", e.StatusCode, e.Message) } + // MciasAuthError is returned for 401 Unauthorized responses. type MciasAuthError struct{ MciasError } + // MciasForbiddenError is returned for 403 Forbidden responses. type MciasForbiddenError struct{ MciasError } + // MciasNotFoundError is returned for 404 Not Found responses. type MciasNotFoundError struct{ MciasError } + // MciasInputError is returned for 400 Bad Request responses. type MciasInputError struct{ MciasError } + // MciasConflictError is returned for 409 Conflict responses. type MciasConflictError struct{ MciasError } + // MciasServerError is returned for 5xx responses. type MciasServerError struct{ MciasError } + // --------------------------------------------------------------------------- // Data types // --------------------------------------------------------------------------- + // Account represents a user or service account. type Account struct { ID string `json:"id"` @@ -51,6 +63,7 @@ type Account struct { UpdatedAt string `json:"updated_at"` TOTPEnabled bool `json:"totp_enabled"` } + // PublicKey represents the server's Ed25519 public key in JWK format. type PublicKey struct { Kty string `json:"kty"` @@ -59,6 +72,7 @@ type PublicKey struct { Use string `json:"use,omitempty"` Alg string `json:"alg,omitempty"` } + // TokenClaims is returned by ValidateToken. type TokenClaims struct { Valid bool `json:"valid"` @@ -66,6 +80,7 @@ type TokenClaims struct { Roles []string `json:"roles,omitempty"` ExpiresAt string `json:"expires_at,omitempty"` } + // PGCreds holds Postgres connection credentials. type PGCreds struct { Host string `json:"host"` @@ -74,9 +89,94 @@ type PGCreds struct { Username string `json:"username"` Password string `json:"password"` } + +// TOTPEnrollResponse is returned by EnrollTOTP. +type TOTPEnrollResponse struct { + Secret string `json:"secret"` + OTPAuthURI string `json:"otpauth_uri"` +} + +// AuditEvent is a single entry in the audit log. +type AuditEvent struct { + ID int `json:"id"` + EventType string `json:"event_type"` + EventTime string `json:"event_time"` + ActorID string `json:"actor_id,omitempty"` + TargetID string `json:"target_id,omitempty"` + IPAddress string `json:"ip_address"` + Details string `json:"details,omitempty"` +} + +// AuditListResponse is returned by ListAudit. +type AuditListResponse struct { + Events []AuditEvent `json:"events"` + Total int `json:"total"` + Limit int `json:"limit"` + Offset int `json:"offset"` +} + +// AuditFilter holds optional filter parameters for ListAudit. +type AuditFilter struct { + Limit int + Offset int + EventType string + ActorID string +} + +// PolicyRuleBody holds the match conditions and effect of a policy rule. +// All fields except Effect are optional; an omitted field acts as a wildcard. +type PolicyRuleBody struct { + Effect string `json:"effect"` + Roles []string `json:"roles,omitempty"` + AccountTypes []string `json:"account_types,omitempty"` + SubjectUUID string `json:"subject_uuid,omitempty"` + Actions []string `json:"actions,omitempty"` + ResourceType string `json:"resource_type,omitempty"` + OwnerMatchesSubject bool `json:"owner_matches_subject,omitempty"` + ServiceNames []string `json:"service_names,omitempty"` + RequiredTags []string `json:"required_tags,omitempty"` +} + +// PolicyRule is a complete operator-defined policy rule as returned by the API. +type PolicyRule struct { + ID int `json:"id"` + Priority int `json:"priority"` + Description string `json:"description"` + Rule PolicyRuleBody `json:"rule"` + Enabled bool `json:"enabled"` + NotBefore string `json:"not_before,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// CreatePolicyRuleRequest holds the parameters for creating a policy rule. +type CreatePolicyRuleRequest struct { + Description string `json:"description"` + Priority int `json:"priority,omitempty"` + Rule PolicyRuleBody `json:"rule"` + NotBefore string `json:"not_before,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +// UpdatePolicyRuleRequest holds the parameters for updating a policy rule. +// All fields are optional; omitted fields are left unchanged. +// Set ClearNotBefore or ClearExpiresAt to true to remove those constraints. +type UpdatePolicyRuleRequest struct { + Description string `json:"description,omitempty"` + Priority *int `json:"priority,omitempty"` + Enabled *bool `json:"enabled,omitempty"` + Rule *PolicyRuleBody `json:"rule,omitempty"` + NotBefore string `json:"not_before,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` + ClearNotBefore bool `json:"clear_not_before,omitempty"` + ClearExpiresAt bool `json:"clear_expires_at,omitempty"` +} + // --------------------------------------------------------------------------- // Options and Client struct // --------------------------------------------------------------------------- + // Options configures the MCIAS client. type Options struct { // CACertPath is an optional path to a PEM-encoded CA certificate for TLS @@ -85,6 +185,7 @@ type Options struct { // Token is an optional pre-existing bearer token. Token string } + // Client is a thread-safe MCIAS REST API client. // Security: the bearer token is guarded by a sync.RWMutex; it is never // written to logs or included in error messages in this library. @@ -94,9 +195,11 @@ type Client struct { mu sync.RWMutex token string } + // --------------------------------------------------------------------------- // Constructor // --------------------------------------------------------------------------- + // New creates a new Client for the given serverURL. // TLS 1.2 is the minimum version enforced on all connections. // If opts.CACertPath is set, that CA certificate is added to the trust pool. @@ -126,20 +229,24 @@ func New(serverURL string, opts Options) (*Client, error) { } return c, nil } + // Token returns the current bearer token (empty string if not logged in). func (c *Client) Token() string { c.mu.RLock() defer c.mu.RUnlock() return c.token } + // --------------------------------------------------------------------------- // Internal helpers // --------------------------------------------------------------------------- + func (c *Client) setToken(tok string) { c.mu.Lock() defer c.mu.Unlock() c.token = tok } + func (c *Client) do(method, path string, body interface{}, out interface{}) error { var reqBody io.Reader if body != nil { @@ -195,6 +302,7 @@ func (c *Client) do(method, path string, body interface{}, out interface{}) erro } return nil } + func makeError(status int, msg string) error { base := MciasError{StatusCode: status, Message: msg} switch { @@ -212,13 +320,16 @@ func makeError(status int, msg string) error { return &MciasServerError{base} } } + // --------------------------------------------------------------------------- -// API methods +// API methods — Public // --------------------------------------------------------------------------- + // Health calls GET /v1/health. Returns nil if the server is healthy. func (c *Client) Health() error { return c.do(http.MethodGet, "/v1/health", nil, nil) } + // GetPublicKey returns the server's Ed25519 public key in JWK format. func (c *Client) GetPublicKey() (*PublicKey, error) { var pk PublicKey @@ -227,6 +338,7 @@ func (c *Client) GetPublicKey() (*PublicKey, error) { } return &pk, nil } + // Login authenticates with username and password. On success the token is // stored in the Client and returned along with the expiry timestamp. // totpCode may be empty for accounts without TOTP. @@ -245,6 +357,23 @@ func (c *Client) Login(username, password, totpCode string) (token, expiresAt st c.setToken(resp.Token) return resp.Token, resp.ExpiresAt, nil } + +// ValidateToken validates a token string against the server. +// Returns claims; Valid is false (not an error) if the token is expired or +// revoked. +func (c *Client) ValidateToken(token string) (*TokenClaims, error) { + var claims TokenClaims + if err := c.do(http.MethodPost, "/v1/token/validate", + map[string]string{"token": token}, &claims); err != nil { + return nil, err + } + return &claims, nil +} + +// --------------------------------------------------------------------------- +// API methods — Authenticated +// --------------------------------------------------------------------------- + // Logout revokes the current token on the server and clears it from the client. func (c *Client) Logout() error { if err := c.do(http.MethodPost, "/v1/auth/logout", nil, nil); err != nil { @@ -253,6 +382,7 @@ func (c *Client) Logout() error { c.setToken("") return nil } + // RenewToken exchanges the current token for a fresh one. // The new token is stored in the client and returned. func (c *Client) RenewToken() (token, expiresAt string, err error) { @@ -266,17 +396,63 @@ func (c *Client) RenewToken() (token, expiresAt string, err error) { c.setToken(resp.Token) return resp.Token, resp.ExpiresAt, nil } -// ValidateToken validates a token string against the server. -// Returns claims; Valid is false (not an error) if the token is expired or -// revoked. -func (c *Client) ValidateToken(token string) (*TokenClaims, error) { - var claims TokenClaims - if err := c.do(http.MethodPost, "/v1/token/validate", - map[string]string{"token": token}, &claims); err != nil { + +// EnrollTOTP begins TOTP enrollment for the authenticated account. +// Returns a base32 secret and an otpauth:// URI for QR-code generation. +// The secret is shown once; it is not retrievable after this call. +// TOTP is not enforced until confirmed via ConfirmTOTP. +func (c *Client) EnrollTOTP() (*TOTPEnrollResponse, error) { + var resp TOTPEnrollResponse + if err := c.do(http.MethodPost, "/v1/auth/totp/enroll", nil, &resp); err != nil { return nil, err } - return &claims, nil + return &resp, nil } + +// ConfirmTOTP completes TOTP enrollment by verifying the current code against +// the pending secret. On success, TOTP becomes required for all future logins. +func (c *Client) ConfirmTOTP(code string) error { + return c.do(http.MethodPost, "/v1/auth/totp/confirm", + map[string]string{"code": code}, nil) +} + +// ChangePassword changes the password of the currently authenticated human +// account. currentPassword is required to prevent token-theft attacks. +// On success, all active sessions except the caller's are revoked. +// +// Security: both passwords are transmitted over TLS only; the server verifies +// currentPassword with constant-time comparison before accepting the change. +func (c *Client) ChangePassword(currentPassword, newPassword string) error { + return c.do(http.MethodPut, "/v1/auth/password", map[string]string{ + "current_password": currentPassword, + "new_password": newPassword, + }, nil) +} + +// --------------------------------------------------------------------------- +// API methods — Admin: Auth +// --------------------------------------------------------------------------- + +// RemoveTOTP clears TOTP enrollment for the given account (admin). +// Use for account recovery when a user has lost their TOTP device. +func (c *Client) RemoveTOTP(accountID string) error { + return c.do(http.MethodDelete, "/v1/auth/totp", + map[string]string{"account_id": accountID}, nil) +} + +// --------------------------------------------------------------------------- +// API methods — Admin: Accounts +// --------------------------------------------------------------------------- + +// ListAccounts returns all accounts. Requires admin role. +func (c *Client) ListAccounts() ([]Account, error) { + var accounts []Account + if err := c.do(http.MethodGet, "/v1/accounts", nil, &accounts); err != nil { + return nil, err + } + return accounts, nil +} + // CreateAccount creates a new account. accountType is "human" or "system". // password is required for human accounts. func (c *Client) CreateAccount(username, accountType, password string) (*Account, error) { @@ -293,14 +469,7 @@ func (c *Client) CreateAccount(username, accountType, password string) (*Account } return &acct, nil } -// ListAccounts returns all accounts. Requires admin role. -func (c *Client) ListAccounts() ([]Account, error) { - var accounts []Account - if err := c.do(http.MethodGet, "/v1/accounts", nil, &accounts); err != nil { - return nil, err - } - return accounts, nil -} + // GetAccount returns the account with the given ID. Requires admin role. func (c *Client) GetAccount(id string) (*Account, error) { var acct Account @@ -309,23 +478,22 @@ func (c *Client) GetAccount(id string) (*Account, error) { } return &acct, nil } -// UpdateAccount updates mutable account fields. Requires admin role. -// Pass an empty string for fields that should not be changed. -func (c *Client) UpdateAccount(id, status string) (*Account, error) { + +// UpdateAccount updates mutable account fields (currently only status). +// Requires admin role. Returns nil on success (HTTP 204). +func (c *Client) UpdateAccount(id, status string) error { req := map[string]string{} if status != "" { req["status"] = status } - var acct Account - if err := c.do(http.MethodPatch, "/v1/accounts/"+id, req, &acct); err != nil { - return nil, err - } - return &acct, nil + return c.do(http.MethodPatch, "/v1/accounts/"+id, req, nil) } + // DeleteAccount soft-deletes the account with the given ID. Requires admin. func (c *Client) DeleteAccount(id string) error { return c.do(http.MethodDelete, "/v1/accounts/"+id, nil, nil) } + // GetRoles returns the roles for accountID. Requires admin. func (c *Client) GetRoles(accountID string) ([]string, error) { var resp struct { @@ -336,11 +504,49 @@ func (c *Client) GetRoles(accountID string) ([]string, error) { } return resp.Roles, nil } + // SetRoles replaces the role set for accountID. Requires admin. func (c *Client) SetRoles(accountID string, roles []string) error { return c.do(http.MethodPut, "/v1/accounts/"+accountID+"/roles", map[string][]string{"roles": roles}, nil) } + +// AdminSetPassword resets a human account's password without requiring the +// current password. Requires admin. All active sessions for the target account +// are revoked on success. +func (c *Client) AdminSetPassword(accountID, newPassword string) error { + return c.do(http.MethodPut, "/v1/accounts/"+accountID+"/password", + map[string]string{"new_password": newPassword}, nil) +} + +// GetAccountTags returns the current tag set for an account. Requires admin. +func (c *Client) GetAccountTags(accountID string) ([]string, error) { + var resp struct { + Tags []string `json:"tags"` + } + if err := c.do(http.MethodGet, "/v1/accounts/"+accountID+"/tags", nil, &resp); err != nil { + return nil, err + } + return resp.Tags, nil +} + +// SetAccountTags replaces the full tag set for an account atomically. +// Pass an empty slice to clear all tags. Requires admin. +func (c *Client) SetAccountTags(accountID string, tags []string) ([]string, error) { + var resp struct { + Tags []string `json:"tags"` + } + if err := c.do(http.MethodPut, "/v1/accounts/"+accountID+"/tags", + map[string][]string{"tags": tags}, &resp); err != nil { + return nil, err + } + return resp.Tags, nil +} + +// --------------------------------------------------------------------------- +// API methods — Admin: Tokens +// --------------------------------------------------------------------------- + // IssueServiceToken issues a long-lived token for a system account. Requires admin. func (c *Client) IssueServiceToken(accountID string) (token, expiresAt string, err error) { var resp struct { @@ -353,10 +559,16 @@ func (c *Client) IssueServiceToken(accountID string) (token, expiresAt string, e } return resp.Token, resp.ExpiresAt, nil } + // RevokeToken revokes a token by JTI. Requires admin. func (c *Client) RevokeToken(jti string) error { return c.do(http.MethodDelete, "/v1/token/"+jti, nil, nil) } + +// --------------------------------------------------------------------------- +// API methods — Admin: Credentials +// --------------------------------------------------------------------------- + // GetPGCreds returns Postgres credentials for accountID. Requires admin. func (c *Client) GetPGCreds(accountID string) (*PGCreds, error) { var creds PGCreds @@ -365,6 +577,7 @@ func (c *Client) GetPGCreds(accountID string) (*PGCreds, error) { } return &creds, nil } + // SetPGCreds stores Postgres credentials for accountID. Requires admin. // The password is sent over TLS and encrypted at rest server-side. func (c *Client) SetPGCreds(accountID, host string, port int, database, username, password string) error { @@ -376,3 +589,78 @@ func (c *Client) SetPGCreds(accountID, host string, port int, database, username "password": password, }, nil) } + +// --------------------------------------------------------------------------- +// API methods — Admin: Audit +// --------------------------------------------------------------------------- + +// ListAudit retrieves audit log entries, newest first. Requires admin. +// f may be zero-valued to use defaults (limit=50, offset=0, no filter). +func (c *Client) ListAudit(f AuditFilter) (*AuditListResponse, error) { + path := "/v1/audit?" + if f.Limit > 0 { + path += fmt.Sprintf("limit=%d&", f.Limit) + } + if f.Offset > 0 { + path += fmt.Sprintf("offset=%d&", f.Offset) + } + if f.EventType != "" { + path += fmt.Sprintf("event_type=%s&", f.EventType) + } + if f.ActorID != "" { + path += fmt.Sprintf("actor_id=%s&", f.ActorID) + } + path = strings.TrimRight(path, "&?") + var resp AuditListResponse + if err := c.do(http.MethodGet, path, nil, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// --------------------------------------------------------------------------- +// API methods — Admin: Policy +// --------------------------------------------------------------------------- + +// ListPolicyRules returns all operator-defined policy rules ordered by +// priority (ascending). Requires admin. +func (c *Client) ListPolicyRules() ([]PolicyRule, error) { + var rules []PolicyRule + if err := c.do(http.MethodGet, "/v1/policy/rules", nil, &rules); err != nil { + return nil, err + } + return rules, nil +} + +// CreatePolicyRule creates a new policy rule. Requires admin. +func (c *Client) CreatePolicyRule(req CreatePolicyRuleRequest) (*PolicyRule, error) { + var rule PolicyRule + if err := c.do(http.MethodPost, "/v1/policy/rules", req, &rule); err != nil { + return nil, err + } + return &rule, nil +} + +// GetPolicyRule returns a single policy rule by integer ID. Requires admin. +func (c *Client) GetPolicyRule(id int) (*PolicyRule, error) { + var rule PolicyRule + if err := c.do(http.MethodGet, fmt.Sprintf("/v1/policy/rules/%d", id), nil, &rule); err != nil { + return nil, err + } + return &rule, nil +} + +// UpdatePolicyRule updates one or more fields of an existing policy rule. +// Requires admin. +func (c *Client) UpdatePolicyRule(id int, req UpdatePolicyRuleRequest) (*PolicyRule, error) { + var rule PolicyRule + if err := c.do(http.MethodPatch, fmt.Sprintf("/v1/policy/rules/%d", id), req, &rule); err != nil { + return nil, err + } + return &rule, nil +} + +// DeletePolicyRule permanently deletes a policy rule. Requires admin. +func (c *Client) DeletePolicyRule(id int) error { + return c.do(http.MethodDelete, fmt.Sprintf("/v1/policy/rules/%d", id), nil, nil) +} diff --git a/clients/go/client_test.go b/clients/go/client_test.go index 777a104..5905fba 100644 --- a/clients/go/client_test.go +++ b/clients/go/client_test.go @@ -2,6 +2,7 @@ // All tests use inline httptest.NewServer mocks to keep this module // self-contained (no cross-module imports). package mciasgoclient_test + import ( "encoding/json" "errors" @@ -9,12 +10,14 @@ import ( "net/http/httptest" "strings" "testing" + mciasgoclient "git.wntrmute.dev/kyle/mcias/clients/go" ) + // --------------------------------------------------------------------------- // helpers // --------------------------------------------------------------------------- -// newTestClient creates a client pointed at the given test server URL. + func newTestClient(t *testing.T, serverURL string) *mciasgoclient.Client { t.Helper() c, err := mciasgoclient.New(serverURL, mciasgoclient.Options{}) @@ -23,19 +26,21 @@ func newTestClient(t *testing.T, serverURL string) *mciasgoclient.Client { } return c } -// writeJSON is a shorthand for writing a JSON response. + func writeJSON(w http.ResponseWriter, status int, v interface{}) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(v) } -// writeError writes a JSON error body with the given status code. + func writeError(w http.ResponseWriter, status int, msg string) { writeJSON(w, status, map[string]string{"error": msg}) } + // --------------------------------------------------------------------------- // TestNew // --------------------------------------------------------------------------- + func TestNew(t *testing.T) { c, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{}) if err != nil { @@ -45,6 +50,7 @@ func TestNew(t *testing.T) { t.Fatal("expected non-nil client") } } + func TestNewWithPresetToken(t *testing.T) { c, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{Token: "preset-tok"}) if err != nil { @@ -54,15 +60,18 @@ func TestNewWithPresetToken(t *testing.T) { t.Errorf("expected preset-tok, got %q", c.Token()) } } + func TestNewBadCACert(t *testing.T) { _, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{CACertPath: "/nonexistent/ca.pem"}) if err == nil { t.Fatal("expected error for missing CA cert file") } } + // --------------------------------------------------------------------------- // TestHealth // --------------------------------------------------------------------------- + func TestHealth(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/health" || r.Method != http.MethodGet { @@ -77,9 +86,7 @@ func TestHealth(t *testing.T) { t.Fatalf("Health: unexpected error: %v", err) } } -// --------------------------------------------------------------------------- -// TestHealthError -// --------------------------------------------------------------------------- + func TestHealthError(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusServiceUnavailable, "service unavailable") @@ -98,9 +105,11 @@ func TestHealthError(t *testing.T) { t.Errorf("expected StatusCode 503, got %d", srvErr.StatusCode) } } + // --------------------------------------------------------------------------- // TestGetPublicKey // --------------------------------------------------------------------------- + func TestGetPublicKey(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/keys/public" { @@ -131,9 +140,11 @@ func TestGetPublicKey(t *testing.T) { t.Error("expected non-empty x") } } + // --------------------------------------------------------------------------- // TestLogin // --------------------------------------------------------------------------- + func TestLogin(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/auth/login" || r.Method != http.MethodPost { @@ -157,14 +168,11 @@ func TestLogin(t *testing.T) { if exp == "" { t.Error("expected non-empty expires_at") } - // Token must be stored in the client. if c.Token() != "tok-abc123" { t.Errorf("Token() = %q, want tok-abc123", c.Token()) } } -// --------------------------------------------------------------------------- -// TestLoginUnauthorized -// --------------------------------------------------------------------------- + func TestLoginUnauthorized(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusUnauthorized, "invalid credentials") @@ -180,16 +188,17 @@ func TestLoginUnauthorized(t *testing.T) { t.Errorf("expected MciasAuthError, got %T: %v", err, err) } } + // --------------------------------------------------------------------------- // TestLogout // --------------------------------------------------------------------------- + func TestLogout(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/v1/auth/login": writeJSON(w, http.StatusOK, map[string]string{ - "token": "tok-logout", - "expires_at": "2099-01-01T00:00:00Z", + "token": "tok-logout", "expires_at": "2099-01-01T00:00:00Z", }) case "/v1/auth/logout": w.WriteHeader(http.StatusOK) @@ -212,21 +221,21 @@ func TestLogout(t *testing.T) { t.Errorf("expected empty token after logout, got %q", c.Token()) } } + // --------------------------------------------------------------------------- // TestRenewToken // --------------------------------------------------------------------------- + func TestRenewToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/v1/auth/login": writeJSON(w, http.StatusOK, map[string]string{ - "token": "tok-old", - "expires_at": "2099-01-01T00:00:00Z", + "token": "tok-old", "expires_at": "2099-01-01T00:00:00Z", }) case "/v1/auth/renew": writeJSON(w, http.StatusOK, map[string]string{ - "token": "tok-new", - "expires_at": "2099-06-01T00:00:00Z", + "token": "tok-new", "expires_at": "2099-06-01T00:00:00Z", }) default: http.Error(w, "not found", http.StatusNotFound) @@ -248,9 +257,125 @@ func TestRenewToken(t *testing.T) { t.Errorf("Token() = %q, want tok-new", c.Token()) } } + +// --------------------------------------------------------------------------- +// TestEnrollTOTP / TestConfirmTOTP +// --------------------------------------------------------------------------- + +func TestEnrollTOTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/auth/totp/enroll" || r.Method != http.MethodPost { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, map[string]string{ + "secret": "JBSWY3DPEHPK3PXP", + "otpauth_uri": "otpauth://totp/MCIAS:alice?secret=JBSWY3DPEHPK3PXP&issuer=MCIAS", + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + resp, err := c.EnrollTOTP() + if err != nil { + t.Fatalf("EnrollTOTP: %v", err) + } + if resp.Secret != "JBSWY3DPEHPK3PXP" { + t.Errorf("expected secret=JBSWY3DPEHPK3PXP, got %q", resp.Secret) + } + if resp.OTPAuthURI == "" { + t.Error("expected non-empty otpauth_uri") + } +} + +func TestConfirmTOTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/auth/totp/confirm" || r.Method != http.MethodPost { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + if err := c.ConfirmTOTP("123456"); err != nil { + t.Fatalf("ConfirmTOTP: unexpected error: %v", err) + } +} + +func TestConfirmTOTPBadCode(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeError(w, http.StatusBadRequest, "invalid TOTP code") + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + err := c.ConfirmTOTP("000000") + if err == nil { + t.Fatal("expected error for bad TOTP code") + } + var inputErr *mciasgoclient.MciasInputError + if !errors.As(err, &inputErr) { + t.Errorf("expected MciasInputError, got %T: %v", err, err) + } +} + +// --------------------------------------------------------------------------- +// TestChangePassword +// --------------------------------------------------------------------------- + +func TestChangePassword(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/auth/password" || r.Method != http.MethodPut { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + if err := c.ChangePassword("old-s3cr3t", "new-s3cr3t-long"); err != nil { + t.Fatalf("ChangePassword: unexpected error: %v", err) + } +} + +func TestChangePasswordWrongCurrent(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeError(w, http.StatusUnauthorized, "current password is incorrect") + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + err := c.ChangePassword("wrong", "new-s3cr3t-long") + if err == nil { + t.Fatal("expected error for wrong current password") + } + var authErr *mciasgoclient.MciasAuthError + if !errors.As(err, &authErr) { + t.Errorf("expected MciasAuthError, got %T: %v", err, err) + } +} + +// --------------------------------------------------------------------------- +// TestRemoveTOTP +// --------------------------------------------------------------------------- + +func TestRemoveTOTP(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/auth/totp" || r.Method != http.MethodDelete { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + if err := c.RemoveTOTP("acct-uuid-42"); err != nil { + t.Fatalf("RemoveTOTP: unexpected error: %v", err) + } +} + // --------------------------------------------------------------------------- // TestValidateToken // --------------------------------------------------------------------------- + func TestValidateToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/token/validate" { @@ -258,10 +383,8 @@ func TestValidateToken(t *testing.T) { return } writeJSON(w, http.StatusOK, map[string]interface{}{ - "valid": true, - "sub": "user-uuid-1", - "roles": []string{"admin"}, - "expires_at": "2099-01-01T00:00:00Z", + "valid": true, "sub": "user-uuid-1", + "roles": []string{"admin"}, "expires_at": "2099-01-01T00:00:00Z", }) })) defer srv.Close() @@ -277,15 +400,10 @@ func TestValidateToken(t *testing.T) { t.Errorf("expected sub=user-uuid-1, got %q", claims.Sub) } } -// --------------------------------------------------------------------------- -// TestValidateTokenInvalid -// --------------------------------------------------------------------------- + func TestValidateTokenInvalid(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Server returns 200 with valid=false for an expired/revoked token. - writeJSON(w, http.StatusOK, map[string]interface{}{ - "valid": false, - }) + writeJSON(w, http.StatusOK, map[string]interface{}{"valid": false}) })) defer srv.Close() c := newTestClient(t, srv.URL) @@ -297,9 +415,11 @@ func TestValidateTokenInvalid(t *testing.T) { t.Error("expected claims.Valid = false") } } + // --------------------------------------------------------------------------- // TestCreateAccount // --------------------------------------------------------------------------- + func TestCreateAccount(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/accounts" || r.Method != http.MethodPost { @@ -307,13 +427,9 @@ func TestCreateAccount(t *testing.T) { return } writeJSON(w, http.StatusCreated, map[string]interface{}{ - "id": "acct-uuid-1", - "username": "bob", - "account_type": "human", - "status": "active", - "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-01-01T00:00:00Z", - "totp_enabled": false, + "id": "acct-uuid-1", "username": "bob", "account_type": "human", + "status": "active", "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false, }) })) defer srv.Close() @@ -329,9 +445,7 @@ func TestCreateAccount(t *testing.T) { t.Errorf("expected username=bob, got %q", acct.Username) } } -// --------------------------------------------------------------------------- -// TestCreateAccountConflict -// --------------------------------------------------------------------------- + func TestCreateAccountConflict(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusConflict, "username already exists") @@ -347,21 +461,19 @@ func TestCreateAccountConflict(t *testing.T) { t.Errorf("expected MciasConflictError, got %T: %v", err, err) } } + // --------------------------------------------------------------------------- // TestListAccounts // --------------------------------------------------------------------------- + func TestListAccounts(t *testing.T) { accounts := []map[string]interface{}{ - { - "id": "acct-1", "username": "alice", "account_type": "human", + {"id": "acct-1", "username": "alice", "account_type": "human", "status": "active", "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false, - }, - { - "id": "acct-2", "username": "bob", "account_type": "human", + "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false}, + {"id": "acct-2", "username": "bob", "account_type": "human", "status": "active", "created_at": "2024-01-02T00:00:00Z", - "updated_at": "2024-01-02T00:00:00Z", "totp_enabled": false, - }, + "updated_at": "2024-01-02T00:00:00Z", "totp_enabled": false}, } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/accounts" || r.Method != http.MethodGet { @@ -383,27 +495,21 @@ func TestListAccounts(t *testing.T) { t.Errorf("expected alice, got %q", list[0].Username) } } + // --------------------------------------------------------------------------- // TestGetAccount // --------------------------------------------------------------------------- + func TestGetAccount(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if !strings.HasPrefix(r.URL.Path, "/v1/accounts/") { + if r.Method != http.MethodGet || !strings.HasPrefix(r.URL.Path, "/v1/accounts/") { http.Error(w, "not found", http.StatusNotFound) return } writeJSON(w, http.StatusOK, map[string]interface{}{ - "id": "acct-uuid-42", - "username": "carol", - "account_type": "human", - "status": "active", - "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-01-01T00:00:00Z", - "totp_enabled": false, + "id": "acct-uuid-42", "username": "carol", "account_type": "human", + "status": "active", "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false, }) })) defer srv.Close() @@ -416,38 +522,30 @@ func TestGetAccount(t *testing.T) { t.Errorf("expected acct-uuid-42, got %q", acct.ID) } } + // --------------------------------------------------------------------------- // TestUpdateAccount // --------------------------------------------------------------------------- + func TestUpdateAccount(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPatch { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - writeJSON(w, http.StatusOK, map[string]interface{}{ - "id": "acct-uuid-42", - "username": "carol", - "account_type": "human", - "status": "disabled", - "created_at": "2024-01-01T00:00:00Z", - "updated_at": "2024-02-01T00:00:00Z", - "totp_enabled": false, - }) + w.WriteHeader(http.StatusNoContent) })) defer srv.Close() c := newTestClient(t, srv.URL) - acct, err := c.UpdateAccount("acct-uuid-42", "disabled") - if err != nil { - t.Fatalf("UpdateAccount: %v", err) - } - if acct.Status != "disabled" { - t.Errorf("expected status=disabled, got %q", acct.Status) + if err := c.UpdateAccount("acct-uuid-42", "inactive"); err != nil { + t.Fatalf("UpdateAccount: unexpected error: %v", err) } } + // --------------------------------------------------------------------------- // TestDeleteAccount // --------------------------------------------------------------------------- + func TestDeleteAccount(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { @@ -462,16 +560,33 @@ func TestDeleteAccount(t *testing.T) { t.Fatalf("DeleteAccount: unexpected error: %v", err) } } + // --------------------------------------------------------------------------- -// TestGetRoles +// TestAdminSetPassword // --------------------------------------------------------------------------- -func TestGetRoles(t *testing.T) { + +func TestAdminSetPassword(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + if r.Method != http.MethodPut || !strings.HasSuffix(r.URL.Path, "/password") { + http.Error(w, "not found", http.StatusNotFound) return } - if !strings.HasSuffix(r.URL.Path, "/roles") { + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + if err := c.AdminSetPassword("acct-uuid-42", "new-s3cr3t-long"); err != nil { + t.Fatalf("AdminSetPassword: unexpected error: %v", err) + } +} + +// --------------------------------------------------------------------------- +// TestGetRoles / TestSetRoles +// --------------------------------------------------------------------------- + +func TestGetRoles(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/roles") { http.Error(w, "not found", http.StatusNotFound) return } @@ -492,9 +607,7 @@ func TestGetRoles(t *testing.T) { t.Errorf("expected roles[0]=admin, got %q", roles[0]) } } -// --------------------------------------------------------------------------- -// TestSetRoles -// --------------------------------------------------------------------------- + func TestSetRoles(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { @@ -509,9 +622,79 @@ func TestSetRoles(t *testing.T) { t.Fatalf("SetRoles: unexpected error: %v", err) } } + // --------------------------------------------------------------------------- -// TestIssueServiceToken +// TestGetAccountTags / TestSetAccountTags // --------------------------------------------------------------------------- + +func TestGetAccountTags(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/tags") { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, map[string]interface{}{ + "tags": []string{"env:production", "svc:payments-api"}, + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + tags, err := c.GetAccountTags("acct-uuid-42") + if err != nil { + t.Fatalf("GetAccountTags: %v", err) + } + if len(tags) != 2 { + t.Errorf("expected 2 tags, got %d", len(tags)) + } + if tags[0] != "env:production" { + t.Errorf("expected tags[0]=env:production, got %q", tags[0]) + } +} + +func TestSetAccountTags(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut || !strings.HasSuffix(r.URL.Path, "/tags") { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, map[string]interface{}{ + "tags": []string{"env:staging"}, + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + tags, err := c.SetAccountTags("acct-uuid-42", []string{"env:staging"}) + if err != nil { + t.Fatalf("SetAccountTags: unexpected error: %v", err) + } + if len(tags) != 1 || tags[0] != "env:staging" { + t.Errorf("expected [env:staging], got %v", tags) + } +} + +func TestSetAccountTagsClear(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + writeJSON(w, http.StatusOK, map[string]interface{}{"tags": []string{}}) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + tags, err := c.SetAccountTags("acct-uuid-42", []string{}) + if err != nil { + t.Fatalf("SetAccountTags (clear): unexpected error: %v", err) + } + if len(tags) != 0 { + t.Errorf("expected empty tags, got %v", tags) + } +} + +// --------------------------------------------------------------------------- +// TestIssueServiceToken / TestRevokeToken +// --------------------------------------------------------------------------- + func TestIssueServiceToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/v1/token/issue" || r.Method != http.MethodPost { @@ -519,8 +702,7 @@ func TestIssueServiceToken(t *testing.T) { return } writeJSON(w, http.StatusOK, map[string]string{ - "token": "svc-tok-xyz", - "expires_at": "2099-01-01T00:00:00Z", + "token": "svc-tok-xyz", "expires_at": "2099-01-01T00:00:00Z", }) })) defer srv.Close() @@ -536,16 +718,10 @@ func TestIssueServiceToken(t *testing.T) { t.Error("expected non-empty expires_at") } } -// --------------------------------------------------------------------------- -// TestRevokeToken -// --------------------------------------------------------------------------- + func TestRevokeToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if !strings.HasPrefix(r.URL.Path, "/v1/token/") { + if r.Method != http.MethodDelete || !strings.HasPrefix(r.URL.Path, "/v1/token/") { http.Error(w, "not found", http.StatusNotFound) return } @@ -557,25 +733,20 @@ func TestRevokeToken(t *testing.T) { t.Fatalf("RevokeToken: unexpected error: %v", err) } } + // --------------------------------------------------------------------------- -// TestGetPGCreds +// TestGetPGCreds / TestSetPGCreds // --------------------------------------------------------------------------- + func TestGetPGCreds(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if !strings.HasSuffix(r.URL.Path, "/pgcreds") { + if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/pgcreds") { http.Error(w, "not found", http.StatusNotFound) return } writeJSON(w, http.StatusOK, map[string]interface{}{ - "host": "db.example.com", - "port": 5432, - "database": "myapp", - "username": "appuser", - "password": "secretpw", + "host": "db.example.com", "port": 5432, + "database": "myapp", "username": "appuser", "password": "secretpw", }) })) defer srv.Close() @@ -594,16 +765,10 @@ func TestGetPGCreds(t *testing.T) { t.Errorf("expected password=secretpw, got %q", creds.Password) } } -// --------------------------------------------------------------------------- -// TestSetPGCreds -// --------------------------------------------------------------------------- + func TestSetPGCreds(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPut { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return - } - if !strings.HasSuffix(r.URL.Path, "/pgcreds") { + if r.Method != http.MethodPut || !strings.HasSuffix(r.URL.Path, "/pgcreds") { http.Error(w, "not found", http.StatusNotFound) return } @@ -611,14 +776,238 @@ func TestSetPGCreds(t *testing.T) { })) defer srv.Close() c := newTestClient(t, srv.URL) - err := c.SetPGCreds("acct-uuid-42", "db.example.com", 5432, "myapp", "appuser", "secretpw") - if err != nil { + if err := c.SetPGCreds("acct-uuid-42", "db.example.com", 5432, "myapp", "appuser", "secretpw"); err != nil { t.Fatalf("SetPGCreds: unexpected error: %v", err) } } + +// --------------------------------------------------------------------------- +// TestListAudit +// --------------------------------------------------------------------------- + +func TestListAudit(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v1/audit") || r.Method != http.MethodGet { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, map[string]interface{}{ + "events": []map[string]interface{}{ + {"id": 42, "event_type": "login_ok", "event_time": "2026-03-11T09:01:23Z", + "actor_id": "acct-uuid-1", "ip_address": "192.0.2.1"}, + }, + "total": 1, "limit": 50, "offset": 0, + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + resp, err := c.ListAudit(mciasgoclient.AuditFilter{}) + if err != nil { + t.Fatalf("ListAudit: %v", err) + } + if resp.Total != 1 { + t.Errorf("expected total=1, got %d", resp.Total) + } + if len(resp.Events) != 1 { + t.Fatalf("expected 1 event, got %d", len(resp.Events)) + } + if resp.Events[0].EventType != "login_ok" { + t.Errorf("expected event_type=login_ok, got %q", resp.Events[0].EventType) + } +} + +func TestListAuditWithFilter(t *testing.T) { + var capturedQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedQuery = r.URL.RawQuery + writeJSON(w, http.StatusOK, map[string]interface{}{ + "events": []map[string]interface{}{}, + "total": 0, "limit": 10, "offset": 5, + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + _, err := c.ListAudit(mciasgoclient.AuditFilter{ + Limit: 10, Offset: 5, EventType: "login_fail", ActorID: "acct-uuid-1", + }) + if err != nil { + t.Fatalf("ListAudit: %v", err) + } + for _, want := range []string{"limit=10", "offset=5", "event_type=login_fail", "actor_id=acct-uuid-1"} { + if !strings.Contains(capturedQuery, want) { + t.Errorf("expected %q in query string, got %q", want, capturedQuery) + } + } +} + +// --------------------------------------------------------------------------- +// TestListPolicyRules +// --------------------------------------------------------------------------- + +func TestListPolicyRules(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/policy/rules" || r.Method != http.MethodGet { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, []map[string]interface{}{ + { + "id": 1, "priority": 100, + "description": "Allow payments-api to read its own pgcreds", + "rule": map[string]interface{}{"effect": "allow", "actions": []string{"pgcreds:read"}}, + "enabled": true, + "created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-11T09:00:00Z", + }, + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + rules, err := c.ListPolicyRules() + if err != nil { + t.Fatalf("ListPolicyRules: %v", err) + } + if len(rules) != 1 { + t.Fatalf("expected 1 rule, got %d", len(rules)) + } + if rules[0].ID != 1 { + t.Errorf("expected id=1, got %d", rules[0].ID) + } + if rules[0].Description != "Allow payments-api to read its own pgcreds" { + t.Errorf("unexpected description: %q", rules[0].Description) + } +} + +// --------------------------------------------------------------------------- +// TestCreatePolicyRule +// --------------------------------------------------------------------------- + +func TestCreatePolicyRule(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/policy/rules" || r.Method != http.MethodPost { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusCreated, map[string]interface{}{ + "id": 7, "priority": 50, "description": "Test rule", + "rule": map[string]interface{}{"effect": "deny"}, + "enabled": true, + "created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-11T09:00:00Z", + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + rule, err := c.CreatePolicyRule(mciasgoclient.CreatePolicyRuleRequest{ + Description: "Test rule", + Priority: 50, + Rule: mciasgoclient.PolicyRuleBody{Effect: "deny"}, + }) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + if rule.ID != 7 { + t.Errorf("expected id=7, got %d", rule.ID) + } + if rule.Priority != 50 { + t.Errorf("expected priority=50, got %d", rule.Priority) + } +} + +// --------------------------------------------------------------------------- +// TestGetPolicyRule +// --------------------------------------------------------------------------- + +func TestGetPolicyRule(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet || r.URL.Path != "/v1/policy/rules/7" { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, map[string]interface{}{ + "id": 7, "priority": 50, "description": "Test rule", + "rule": map[string]interface{}{"effect": "allow"}, + "enabled": true, + "created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-11T09:00:00Z", + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + rule, err := c.GetPolicyRule(7) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if rule.ID != 7 { + t.Errorf("expected id=7, got %d", rule.ID) + } +} + +func TestGetPolicyRuleNotFound(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + writeError(w, http.StatusNotFound, "rule not found") + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + _, err := c.GetPolicyRule(999) + if err == nil { + t.Fatal("expected error for 404") + } + var notFoundErr *mciasgoclient.MciasNotFoundError + if !errors.As(err, ¬FoundErr) { + t.Errorf("expected MciasNotFoundError, got %T: %v", err, err) + } +} + +// --------------------------------------------------------------------------- +// TestUpdatePolicyRule +// --------------------------------------------------------------------------- + +func TestUpdatePolicyRule(t *testing.T) { + enabled := false + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPatch || r.URL.Path != "/v1/policy/rules/7" { + http.Error(w, "not found", http.StatusNotFound) + return + } + writeJSON(w, http.StatusOK, map[string]interface{}{ + "id": 7, "priority": 50, "description": "Test rule", + "rule": map[string]interface{}{"effect": "allow"}, + "enabled": false, + "created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-12T10:00:00Z", + }) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + rule, err := c.UpdatePolicyRule(7, mciasgoclient.UpdatePolicyRuleRequest{Enabled: &enabled}) + if err != nil { + t.Fatalf("UpdatePolicyRule: %v", err) + } + if rule.Enabled { + t.Error("expected enabled=false after update") + } +} + +// --------------------------------------------------------------------------- +// TestDeletePolicyRule +// --------------------------------------------------------------------------- + +func TestDeletePolicyRule(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete || r.URL.Path != "/v1/policy/rules/7" { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + c := newTestClient(t, srv.URL) + if err := c.DeletePolicyRule(7); err != nil { + t.Fatalf("DeletePolicyRule: unexpected error: %v", err) + } +} + // --------------------------------------------------------------------------- // TestIntegration: full login → validate → logout flow // --------------------------------------------------------------------------- + func TestIntegration(t *testing.T) { const sessionToken = "integration-tok-999" mux := http.NewServeMux() @@ -640,8 +1029,7 @@ func TestIntegration(t *testing.T) { return } writeJSON(w, http.StatusOK, map[string]string{ - "token": sessionToken, - "expires_at": "2099-01-01T00:00:00Z", + "token": sessionToken, "expires_at": "2099-01-01T00:00:00Z", }) }) mux.HandleFunc("/v1/token/validate", func(w http.ResponseWriter, r *http.Request) { @@ -658,15 +1046,11 @@ func TestIntegration(t *testing.T) { } if body.Token == sessionToken { writeJSON(w, http.StatusOK, map[string]interface{}{ - "valid": true, - "sub": "alice-uuid", - "roles": []string{"user"}, - "expires_at": "2099-01-01T00:00:00Z", + "valid": true, "sub": "alice-uuid", + "roles": []string{"user"}, "expires_at": "2099-01-01T00:00:00Z", }) } else { - writeJSON(w, http.StatusOK, map[string]interface{}{ - "valid": false, - }) + writeJSON(w, http.StatusOK, map[string]interface{}{"valid": false}) } }) mux.HandleFunc("/v1/auth/logout", func(w http.ResponseWriter, r *http.Request) { @@ -674,9 +1058,7 @@ func TestIntegration(t *testing.T) { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } - // Verify Authorization header is present. - auth := r.Header.Get("Authorization") - if auth == "" { + if r.Header.Get("Authorization") == "" { writeError(w, http.StatusUnauthorized, "missing token") return } @@ -685,7 +1067,8 @@ func TestIntegration(t *testing.T) { srv := httptest.NewServer(mux) defer srv.Close() c := newTestClient(t, srv.URL) - // Step 1: login with wrong credentials should fail. + + // Step 1: wrong credentials → MciasAuthError. _, _, err := c.Login("alice", "wrong-password", "") if err == nil { t.Fatal("expected error for wrong credentials") @@ -694,7 +1077,8 @@ func TestIntegration(t *testing.T) { if !errors.As(err, &authErr) { t.Errorf("expected MciasAuthError, got %T", err) } - // Step 2: login with correct credentials. + + // Step 2: correct login. tok, _, err := c.Login("alice", "correct-horse", "") if err != nil { t.Fatalf("Login: %v", err) @@ -702,7 +1086,8 @@ func TestIntegration(t *testing.T) { if tok != sessionToken { t.Errorf("expected %q, got %q", sessionToken, tok) } - // Step 3: validate the returned token. + + // Step 3: validate → valid=true. claims, err := c.ValidateToken(tok) if err != nil { t.Fatalf("ValidateToken: %v", err) @@ -713,7 +1098,8 @@ func TestIntegration(t *testing.T) { if claims.Sub != "alice-uuid" { t.Errorf("expected sub=alice-uuid, got %q", claims.Sub) } - // Step 4: validate an unknown token returns Valid=false, not an error. + + // Step 4: garbage token → valid=false (not an error). claims2, err := c.ValidateToken("garbage-token") if err != nil { t.Fatalf("ValidateToken(garbage): unexpected error: %v", err) @@ -721,7 +1107,8 @@ func TestIntegration(t *testing.T) { if claims2.Valid { t.Error("expected Valid=false for garbage token") } - // Step 5: logout clears the stored token. + + // Step 5: logout clears stored token. if err := c.Logout(); err != nil { t.Fatalf("Logout: %v", err) } diff --git a/clients/python/mcias_client/__init__.py b/clients/python/mcias_client/__init__.py index 65cfdb5..2754c1a 100644 --- a/clients/python/mcias_client/__init__.py +++ b/clients/python/mcias_client/__init__.py @@ -7,9 +7,10 @@ from ._errors import ( MciasForbiddenError, MciasInputError, MciasNotFoundError, + MciasRateLimitError, MciasServerError, ) -from ._models import Account, PGCreds, PublicKey, TokenClaims +from ._models import Account, PGCreds, PolicyRule, PublicKey, RuleBody, TokenClaims __all__ = [ "Client", @@ -19,9 +20,12 @@ __all__ = [ "MciasNotFoundError", "MciasInputError", "MciasConflictError", + "MciasRateLimitError", "MciasServerError", "Account", "PublicKey", "TokenClaims", "PGCreds", + "PolicyRule", + "RuleBody", ] diff --git a/clients/python/mcias_client/_client.py b/clients/python/mcias_client/_client.py index 67f1688..de3f543 100644 --- a/clients/python/mcias_client/_client.py +++ b/clients/python/mcias_client/_client.py @@ -8,7 +8,7 @@ from typing import Any import httpx from ._errors import raise_for_status -from ._models import Account, PGCreds, PublicKey, TokenClaims +from ._models import Account, PGCreds, PolicyRule, PublicKey, RuleBody, TokenClaims class Client: @@ -76,6 +76,29 @@ class Client: if status == 204 or not response.content: return None return response.json() # type: ignore[no-any-return] + def _request_list( + self, + method: str, + path: str, + *, + json: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """Send a request that returns a JSON array at the top level.""" + url = f"{self._base_url}{path}" + headers: dict[str, str] = {} + if self.token is not None: + headers["Authorization"] = f"Bearer {self.token}" + response = self._http.request(method, url, json=json, headers=headers) + status = response.status_code + if status >= 400: + try: + body = response.json() + message = str(body.get("error", response.text)) + except Exception: + message = response.text + raise_for_status(status, message) + return response.json() # type: ignore[no-any-return] + # ── Public ──────────────────────────────────────────────────────────────── def health(self) -> None: """GET /v1/health — liveness check.""" self._request("GET", "/v1/health") @@ -105,6 +128,12 @@ class Client: expires_at = str(data["expires_at"]) self.token = token return token, expires_at + def validate_token(self, token: str) -> TokenClaims: + """POST /v1/token/validate — check whether a token is valid.""" + data = self._request("POST", "/v1/token/validate", json={"token": token}) + assert data is not None + return TokenClaims.from_dict(data) + # ── Authenticated ────────────────────────────────────────────────────────── def logout(self) -> None: """POST /v1/auth/logout — invalidate the current token.""" self._request("POST", "/v1/auth/logout") @@ -119,11 +148,45 @@ class Client: expires_at = str(data["expires_at"]) self.token = token return token, expires_at - def validate_token(self, token: str) -> TokenClaims: - """POST /v1/token/validate — check whether a token is valid.""" - data = self._request("POST", "/v1/token/validate", json={"token": token}) + def enroll_totp(self) -> tuple[str, str]: + """POST /v1/auth/totp/enroll — begin TOTP enrollment. + Returns (secret, otpauth_uri). The secret is shown only once. + """ + data = self._request("POST", "/v1/auth/totp/enroll") assert data is not None - return TokenClaims.from_dict(data) + return str(data["secret"]), str(data["otpauth_uri"]) + def confirm_totp(self, code: str) -> None: + """POST /v1/auth/totp/confirm — confirm TOTP enrollment with a code.""" + self._request("POST", "/v1/auth/totp/confirm", json={"code": code}) + def change_password(self, current_password: str, new_password: str) -> None: + """PUT /v1/auth/password — change own password (self-service).""" + self._request( + "PUT", + "/v1/auth/password", + json={"current_password": current_password, "new_password": new_password}, + ) + # ── Admin — Auth ────────────────────────────────────────────────────────── + def remove_totp(self, account_id: str) -> None: + """DELETE /v1/auth/totp — remove TOTP from an account (admin).""" + self._request("DELETE", "/v1/auth/totp", json={"account_id": account_id}) + # ── Admin — Tokens ──────────────────────────────────────────────────────── + def issue_service_token(self, account_id: str) -> tuple[str, str]: + """POST /v1/token/issue — issue a long-lived service token (admin). + Returns (token, expires_at). + """ + data = self._request("POST", "/v1/token/issue", json={"account_id": account_id}) + assert data is not None + return str(data["token"]), str(data["expires_at"]) + def revoke_token(self, jti: str) -> None: + """DELETE /v1/token/{jti} — revoke a token by JTI (admin).""" + self._request("DELETE", f"/v1/token/{jti}") + # ── Admin — Accounts ────────────────────────────────────────────────────── + def list_accounts(self) -> list[Account]: + """GET /v1/accounts — list all accounts (admin). + The API returns a JSON array directly (no wrapper object). + """ + items = self._request_list("GET", "/v1/accounts") + return [Account.from_dict(a) for a in items] def create_account( self, username: str, @@ -131,7 +194,7 @@ class Client: *, password: str | None = None, ) -> Account: - """POST /v1/accounts — create a new account.""" + """POST /v1/accounts — create a new account (admin).""" payload: dict[str, Any] = { "username": username, "account_type": account_type, @@ -141,14 +204,8 @@ class Client: data = self._request("POST", "/v1/accounts", json=payload) assert data is not None return Account.from_dict(data) - def list_accounts(self) -> list[Account]: - """GET /v1/accounts — list all accounts.""" - data = self._request("GET", "/v1/accounts") - assert data is not None - accounts_raw = data.get("accounts") or [] - return [Account.from_dict(a) for a in accounts_raw] def get_account(self, account_id: str) -> Account: - """GET /v1/accounts/{id} — retrieve a single account.""" + """GET /v1/accounts/{id} — retrieve a single account (admin).""" data = self._request("GET", f"/v1/accounts/{account_id}") assert data is not None return Account.from_dict(data) @@ -157,42 +214,40 @@ class Client: account_id: str, *, status: str | None = None, - ) -> Account: - """PATCH /v1/accounts/{id} — update account fields.""" + ) -> None: + """PATCH /v1/accounts/{id} — update account fields (admin). + Currently only `status` is patchable. Returns None (204 No Content). + """ payload: dict[str, Any] = {} if status is not None: payload["status"] = status - data = self._request("PATCH", f"/v1/accounts/{account_id}", json=payload) - assert data is not None - return Account.from_dict(data) + self._request("PATCH", f"/v1/accounts/{account_id}", json=payload) def delete_account(self, account_id: str) -> None: - """DELETE /v1/accounts/{id} — permanently remove an account.""" + """DELETE /v1/accounts/{id} — soft-delete an account (admin).""" self._request("DELETE", f"/v1/accounts/{account_id}") def get_roles(self, account_id: str) -> list[str]: - """GET /v1/accounts/{id}/roles — list roles for an account.""" + """GET /v1/accounts/{id}/roles — list roles for an account (admin).""" data = self._request("GET", f"/v1/accounts/{account_id}/roles") assert data is not None roles_raw = data.get("roles") or [] return [str(r) for r in roles_raw] def set_roles(self, account_id: str, roles: list[str]) -> None: - """PUT /v1/accounts/{id}/roles — replace the full role set.""" + """PUT /v1/accounts/{id}/roles — replace the full role set (admin).""" self._request( "PUT", f"/v1/accounts/{account_id}/roles", json={"roles": roles}, ) - def issue_service_token(self, account_id: str) -> tuple[str, str]: - """POST /v1/accounts/{id}/token — issue a long-lived service token. - Returns (token, expires_at). - """ - data = self._request("POST", f"/v1/accounts/{account_id}/token") - assert data is not None - return str(data["token"]), str(data["expires_at"]) - def revoke_token(self, jti: str) -> None: - """DELETE /v1/token/{jti} — revoke a token by JTI.""" - self._request("DELETE", f"/v1/token/{jti}") + def admin_set_password(self, account_id: str, new_password: str) -> None: + """PUT /v1/accounts/{id}/password — reset a password without the old one (admin).""" + self._request( + "PUT", + f"/v1/accounts/{account_id}/password", + json={"new_password": new_password}, + ) + # ── Admin — Credentials ─────────────────────────────────────────────────── def get_pg_creds(self, account_id: str) -> PGCreds: - """GET /v1/accounts/{id}/pgcreds — retrieve Postgres credentials.""" + """GET /v1/accounts/{id}/pgcreds — retrieve Postgres credentials (admin).""" data = self._request("GET", f"/v1/accounts/{account_id}/pgcreds") assert data is not None return PGCreds.from_dict(data) @@ -205,7 +260,7 @@ class Client: username: str, password: str, ) -> None: - """PUT /v1/accounts/{id}/pgcreds — store or replace Postgres credentials.""" + """PUT /v1/accounts/{id}/pgcreds — store or replace Postgres credentials (admin).""" payload: dict[str, Any] = { "host": host, "port": port, @@ -214,3 +269,89 @@ class Client: "password": password, } self._request("PUT", f"/v1/accounts/{account_id}/pgcreds", json=payload) + # ── Admin — Policy ──────────────────────────────────────────────────────── + def get_account_tags(self, account_id: str) -> list[str]: + """GET /v1/accounts/{id}/tags — get account tags (admin).""" + data = self._request("GET", f"/v1/accounts/{account_id}/tags") + assert data is not None + return [str(t) for t in (data.get("tags") or [])] + def set_account_tags(self, account_id: str, tags: list[str]) -> list[str]: + """PUT /v1/accounts/{id}/tags — replace the full tag set (admin). + Returns the updated tag list. + """ + data = self._request( + "PUT", + f"/v1/accounts/{account_id}/tags", + json={"tags": tags}, + ) + assert data is not None + return [str(t) for t in (data.get("tags") or [])] + def list_policy_rules(self) -> list[PolicyRule]: + """GET /v1/policy/rules — list all operator policy rules (admin).""" + items = self._request_list("GET", "/v1/policy/rules") + return [PolicyRule.from_dict(r) for r in items] + def create_policy_rule( + self, + description: str, + rule: RuleBody, + *, + priority: int | None = None, + not_before: str | None = None, + expires_at: str | None = None, + ) -> PolicyRule: + """POST /v1/policy/rules — create a policy rule (admin).""" + payload: dict[str, Any] = { + "description": description, + "rule": rule.to_dict(), + } + if priority is not None: + payload["priority"] = priority + if not_before is not None: + payload["not_before"] = not_before + if expires_at is not None: + payload["expires_at"] = expires_at + data = self._request("POST", "/v1/policy/rules", json=payload) + assert data is not None + return PolicyRule.from_dict(data) + def get_policy_rule(self, rule_id: int) -> PolicyRule: + """GET /v1/policy/rules/{id} — get a policy rule (admin).""" + data = self._request("GET", f"/v1/policy/rules/{rule_id}") + assert data is not None + return PolicyRule.from_dict(data) + def update_policy_rule( + self, + rule_id: int, + *, + description: str | None = None, + priority: int | None = None, + enabled: bool | None = None, + rule: RuleBody | None = None, + not_before: str | None = None, + expires_at: str | None = None, + clear_not_before: bool | None = None, + clear_expires_at: bool | None = None, + ) -> PolicyRule: + """PATCH /v1/policy/rules/{id} — update a policy rule (admin).""" + payload: dict[str, Any] = {} + if description is not None: + payload["description"] = description + if priority is not None: + payload["priority"] = priority + if enabled is not None: + payload["enabled"] = enabled + if rule is not None: + payload["rule"] = rule.to_dict() + if not_before is not None: + payload["not_before"] = not_before + if expires_at is not None: + payload["expires_at"] = expires_at + if clear_not_before is not None: + payload["clear_not_before"] = clear_not_before + if clear_expires_at is not None: + payload["clear_expires_at"] = clear_expires_at + data = self._request("PATCH", f"/v1/policy/rules/{rule_id}", json=payload) + assert data is not None + return PolicyRule.from_dict(data) + def delete_policy_rule(self, rule_id: int) -> None: + """DELETE /v1/policy/rules/{id} — delete a policy rule (admin).""" + self._request("DELETE", f"/v1/policy/rules/{rule_id}") diff --git a/clients/python/mcias_client/_errors.py b/clients/python/mcias_client/_errors.py index fa21a0d..c141dfc 100644 --- a/clients/python/mcias_client/_errors.py +++ b/clients/python/mcias_client/_errors.py @@ -15,6 +15,8 @@ class MciasInputError(MciasError): """400 Bad Request — malformed request.""" class MciasConflictError(MciasError): """409 Conflict — e.g. duplicate username.""" +class MciasRateLimitError(MciasError): + """429 Too Many Requests — rate limit exceeded.""" class MciasServerError(MciasError): """5xx — unexpected server error.""" def raise_for_status(status_code: int, message: str) -> None: @@ -25,6 +27,7 @@ def raise_for_status(status_code: int, message: str) -> None: 403: MciasForbiddenError, 404: MciasNotFoundError, 409: MciasConflictError, + 429: MciasRateLimitError, } cls = exc_map.get(status_code, MciasServerError) raise cls(status_code, message) diff --git a/clients/python/mcias_client/_models.py b/clients/python/mcias_client/_models.py index 0fa026d..e84ea38 100644 --- a/clients/python/mcias_client/_models.py +++ b/clients/python/mcias_client/_models.py @@ -1,6 +1,6 @@ """Data models for MCIAS API responses.""" from dataclasses import dataclass, field -from typing import cast +from typing import Any, cast @dataclass @@ -74,3 +74,73 @@ class PGCreds: username=str(d["username"]), password=str(d["password"]), ) +@dataclass +class RuleBody: + """Match conditions and effect of a policy rule.""" + effect: str + roles: list[str] = field(default_factory=list) + account_types: list[str] = field(default_factory=list) + subject_uuid: str | None = None + actions: list[str] = field(default_factory=list) + resource_type: str | None = None + owner_matches_subject: bool | None = None + service_names: list[str] = field(default_factory=list) + required_tags: list[str] = field(default_factory=list) + @classmethod + def from_dict(cls, d: dict[str, object]) -> "RuleBody": + return cls( + effect=str(d["effect"]), + roles=[str(r) for r in cast(list[Any], d.get("roles") or [])], + account_types=[str(t) for t in cast(list[Any], d.get("account_types") or [])], + subject_uuid=str(d["subject_uuid"]) if d.get("subject_uuid") is not None else None, + actions=[str(a) for a in cast(list[Any], d.get("actions") or [])], + resource_type=str(d["resource_type"]) if d.get("resource_type") is not None else None, + owner_matches_subject=bool(d["owner_matches_subject"]) if d.get("owner_matches_subject") is not None else None, + service_names=[str(s) for s in cast(list[Any], d.get("service_names") or [])], + required_tags=[str(t) for t in cast(list[Any], d.get("required_tags") or [])], + ) + def to_dict(self) -> dict[str, Any]: + """Serialise to a JSON-compatible dict, omitting None/empty fields.""" + out: dict[str, Any] = {"effect": self.effect} + if self.roles: + out["roles"] = self.roles + if self.account_types: + out["account_types"] = self.account_types + if self.subject_uuid is not None: + out["subject_uuid"] = self.subject_uuid + if self.actions: + out["actions"] = self.actions + if self.resource_type is not None: + out["resource_type"] = self.resource_type + if self.owner_matches_subject is not None: + out["owner_matches_subject"] = self.owner_matches_subject + if self.service_names: + out["service_names"] = self.service_names + if self.required_tags: + out["required_tags"] = self.required_tags + return out +@dataclass +class PolicyRule: + """An operator-defined policy rule.""" + id: int + priority: int + description: str + rule: RuleBody + enabled: bool + created_at: str + updated_at: str + not_before: str | None = None + expires_at: str | None = None + @classmethod + def from_dict(cls, d: dict[str, object]) -> "PolicyRule": + return cls( + id=int(cast(int, d["id"])), + priority=int(cast(int, d["priority"])), + description=str(d["description"]), + rule=RuleBody.from_dict(cast(dict[str, object], d["rule"])), + enabled=bool(d["enabled"]), + created_at=str(d["created_at"]), + updated_at=str(d["updated_at"]), + not_before=str(d["not_before"]) if d.get("not_before") is not None else None, + expires_at=str(d["expires_at"]) if d.get("expires_at") is not None else None, + ) diff --git a/clients/python/tests/test_client.py b/clients/python/tests/test_client.py index 9e32919..6c40a6b 100644 --- a/clients/python/tests/test_client.py +++ b/clients/python/tests/test_client.py @@ -13,15 +13,16 @@ from mcias_client import ( MciasForbiddenError, MciasInputError, MciasNotFoundError, + MciasRateLimitError, MciasServerError, ) -from mcias_client._models import Account, PGCreds, PublicKey, TokenClaims +from mcias_client._models import Account, PGCreds, PolicyRule, PublicKey, RuleBody, TokenClaims BASE_URL = "https://auth.example.com" SAMPLE_ACCOUNT: dict[str, object] = { "id": "acc-001", "username": "alice", - "account_type": "user", + "account_type": "human", "status": "active", "created_at": "2024-01-01T00:00:00Z", "updated_at": "2024-01-01T00:00:00Z", @@ -34,6 +35,24 @@ SAMPLE_PK: dict[str, object] = { "use": "sig", "alg": "EdDSA", } +SAMPLE_RULE_BODY: dict[str, object] = { + "effect": "allow", + "roles": ["svc:payments-api"], + "actions": ["pgcreds:read"], + "resource_type": "pgcreds", + "owner_matches_subject": True, +} +SAMPLE_POLICY_RULE: dict[str, object] = { + "id": 1, + "priority": 100, + "description": "Allow payments-api to read its own pgcreds", + "rule": SAMPLE_RULE_BODY, + "enabled": True, + "not_before": None, + "expires_at": None, + "created_at": "2026-03-11T09:00:00Z", + "updated_at": "2026-03-11T09:00:00Z", +} @pytest.fixture def client() -> Client: return Client(BASE_URL) @@ -88,6 +107,16 @@ def test_login_success(client: Client) -> None: assert expires_at == "2099-01-01T00:00:00Z" assert client.token == "jwt-token-abc" @respx.mock +def test_login_with_totp(client: Client) -> None: + respx.post(f"{BASE_URL}/v1/auth/login").mock( + return_value=httpx.Response( + 200, + json={"token": "jwt-token-totp", "expires_at": "2099-01-01T00:00:00Z"}, + ) + ) + token, _ = client.login("alice", "s3cr3t", totp_code="123456") + assert token == "jwt-token-totp" +@respx.mock def test_login_unauthorized(client: Client) -> None: respx.post(f"{BASE_URL}/v1/auth/login").mock( return_value=httpx.Response( @@ -98,6 +127,14 @@ def test_login_unauthorized(client: Client) -> None: client.login("alice", "wrong") assert exc_info.value.status_code == 401 @respx.mock +def test_login_rate_limited(client: Client) -> None: + respx.post(f"{BASE_URL}/v1/auth/login").mock( + return_value=httpx.Response(429, json={"error": "rate limit exceeded", "code": "rate_limited"}) + ) + with pytest.raises(MciasRateLimitError) as exc_info: + client.login("alice", "s3cr3t") + assert exc_info.value.status_code == 429 +@respx.mock def test_logout_clears_token(admin_client: Client) -> None: respx.post(f"{BASE_URL}/v1/auth/logout").mock( return_value=httpx.Response(204) @@ -147,11 +184,58 @@ def test_validate_token_invalid(admin_client: Client) -> None: claims = admin_client.validate_token("expired-token") assert claims.valid is False @respx.mock +def test_enroll_totp(admin_client: Client) -> None: + respx.post(f"{BASE_URL}/v1/auth/totp/enroll").mock( + return_value=httpx.Response( + 200, + json={"secret": "JBSWY3DPEHPK3PXP", "otpauth_uri": "otpauth://totp/MCIAS:alice?secret=JBSWY3DPEHPK3PXP&issuer=MCIAS"}, + ) + ) + secret, uri = admin_client.enroll_totp() + assert secret == "JBSWY3DPEHPK3PXP" + assert "otpauth://totp/" in uri +@respx.mock +def test_confirm_totp(admin_client: Client) -> None: + respx.post(f"{BASE_URL}/v1/auth/totp/confirm").mock( + return_value=httpx.Response(204) + ) + admin_client.confirm_totp("123456") # should not raise +@respx.mock +def test_change_password(admin_client: Client) -> None: + respx.put(f"{BASE_URL}/v1/auth/password").mock( + return_value=httpx.Response(204) + ) + admin_client.change_password("old-pass", "new-pass-long-enough") # should not raise +@respx.mock +def test_remove_totp(admin_client: Client) -> None: + respx.delete(f"{BASE_URL}/v1/auth/totp").mock( + return_value=httpx.Response(204) + ) + admin_client.remove_totp("acc-001") # should not raise +@respx.mock +def test_issue_service_token(admin_client: Client) -> None: + respx.post(f"{BASE_URL}/v1/token/issue").mock( + return_value=httpx.Response( + 200, + json={"token": "svc-token-xyz", "expires_at": "2099-12-31T00:00:00Z"}, + ) + ) + token, expires_at = admin_client.issue_service_token("acc-001") + assert token == "svc-token-xyz" + assert expires_at == "2099-12-31T00:00:00Z" +@respx.mock +def test_revoke_token(admin_client: Client) -> None: + jti = "some-jti-uuid" + respx.delete(f"{BASE_URL}/v1/token/{jti}").mock( + return_value=httpx.Response(204) + ) + admin_client.revoke_token(jti) # should not raise +@respx.mock def test_create_account(admin_client: Client) -> None: respx.post(f"{BASE_URL}/v1/accounts").mock( return_value=httpx.Response(201, json=SAMPLE_ACCOUNT) ) - acc = admin_client.create_account("alice", "user", password="pass123") + acc = admin_client.create_account("alice", "human", password="pass123") assert isinstance(acc, Account) assert acc.id == "acc-001" assert acc.username == "alice" @@ -161,15 +245,14 @@ def test_create_account_conflict(admin_client: Client) -> None: return_value=httpx.Response(409, json={"error": "username already exists"}) ) with pytest.raises(MciasConflictError) as exc_info: - admin_client.create_account("alice", "user") + admin_client.create_account("alice", "human") assert exc_info.value.status_code == 409 @respx.mock def test_list_accounts(admin_client: Client) -> None: second = {**SAMPLE_ACCOUNT, "id": "acc-002"} + # API returns a plain JSON array, not a wrapper object respx.get(f"{BASE_URL}/v1/accounts").mock( - return_value=httpx.Response( - 200, json={"accounts": [SAMPLE_ACCOUNT, second]} - ) + return_value=httpx.Response(200, json=[SAMPLE_ACCOUNT, second]) ) accounts = admin_client.list_accounts() assert len(accounts) == 2 @@ -183,12 +266,12 @@ def test_get_account(admin_client: Client) -> None: assert acc.id == "acc-001" @respx.mock def test_update_account(admin_client: Client) -> None: - updated = {**SAMPLE_ACCOUNT, "status": "suspended"} + # PATCH /v1/accounts/{id} returns 204 No Content respx.patch(f"{BASE_URL}/v1/accounts/acc-001").mock( - return_value=httpx.Response(200, json=updated) + return_value=httpx.Response(204) ) - acc = admin_client.update_account("acc-001", status="suspended") - assert acc.status == "suspended" + result = admin_client.update_account("acc-001", status="inactive") + assert result is None @respx.mock def test_delete_account(admin_client: Client) -> None: respx.delete(f"{BASE_URL}/v1/accounts/acc-001").mock( @@ -209,23 +292,11 @@ def test_set_roles(admin_client: Client) -> None: ) admin_client.set_roles("acc-001", ["viewer"]) # should not raise @respx.mock -def test_issue_service_token(admin_client: Client) -> None: - respx.post(f"{BASE_URL}/v1/accounts/acc-001/token").mock( - return_value=httpx.Response( - 200, - json={"token": "svc-token-xyz", "expires_at": "2099-12-31T00:00:00Z"}, - ) - ) - token, expires_at = admin_client.issue_service_token("acc-001") - assert token == "svc-token-xyz" - assert expires_at == "2099-12-31T00:00:00Z" -@respx.mock -def test_revoke_token(admin_client: Client) -> None: - jti = "some-jti-uuid" - respx.delete(f"{BASE_URL}/v1/token/{jti}").mock( +def test_admin_set_password(admin_client: Client) -> None: + respx.put(f"{BASE_URL}/v1/accounts/acc-001/password").mock( return_value=httpx.Response(204) ) - admin_client.revoke_token(jti) # should not raise + admin_client.admin_set_password("acc-001", "new-secure-password") # should not raise SAMPLE_PG_CREDS: dict[str, object] = { "host": "db.example.com", "port": 5432, @@ -256,6 +327,68 @@ def test_set_pg_creds(admin_client: Client) -> None: username="appuser", password="s3cr3t", ) # should not raise +@respx.mock +def test_get_account_tags(admin_client: Client) -> None: + respx.get(f"{BASE_URL}/v1/accounts/acc-001/tags").mock( + return_value=httpx.Response(200, json={"tags": ["env:production", "svc:payments-api"]}) + ) + tags = admin_client.get_account_tags("acc-001") + assert tags == ["env:production", "svc:payments-api"] +@respx.mock +def test_set_account_tags(admin_client: Client) -> None: + respx.put(f"{BASE_URL}/v1/accounts/acc-001/tags").mock( + return_value=httpx.Response(200, json={"tags": ["env:staging"]}) + ) + tags = admin_client.set_account_tags("acc-001", ["env:staging"]) + assert tags == ["env:staging"] +@respx.mock +def test_list_policy_rules(admin_client: Client) -> None: + respx.get(f"{BASE_URL}/v1/policy/rules").mock( + return_value=httpx.Response(200, json=[SAMPLE_POLICY_RULE]) + ) + rules = admin_client.list_policy_rules() + assert len(rules) == 1 + assert isinstance(rules[0], PolicyRule) + assert rules[0].id == 1 + assert rules[0].rule.effect == "allow" +@respx.mock +def test_create_policy_rule(admin_client: Client) -> None: + respx.post(f"{BASE_URL}/v1/policy/rules").mock( + return_value=httpx.Response(201, json=SAMPLE_POLICY_RULE) + ) + rule_body = RuleBody(effect="allow", actions=["pgcreds:read"], resource_type="pgcreds") + rule = admin_client.create_policy_rule( + "Allow payments-api to read its own pgcreds", + rule_body, + priority=50, + ) + assert isinstance(rule, PolicyRule) + assert rule.id == 1 + assert rule.description == "Allow payments-api to read its own pgcreds" +@respx.mock +def test_get_policy_rule(admin_client: Client) -> None: + respx.get(f"{BASE_URL}/v1/policy/rules/1").mock( + return_value=httpx.Response(200, json=SAMPLE_POLICY_RULE) + ) + rule = admin_client.get_policy_rule(1) + assert isinstance(rule, PolicyRule) + assert rule.id == 1 + assert rule.enabled is True +@respx.mock +def test_update_policy_rule(admin_client: Client) -> None: + updated = {**SAMPLE_POLICY_RULE, "enabled": False} + respx.patch(f"{BASE_URL}/v1/policy/rules/1").mock( + return_value=httpx.Response(200, json=updated) + ) + rule = admin_client.update_policy_rule(1, enabled=False) + assert isinstance(rule, PolicyRule) + assert rule.enabled is False +@respx.mock +def test_delete_policy_rule(admin_client: Client) -> None: + respx.delete(f"{BASE_URL}/v1/policy/rules/1").mock( + return_value=httpx.Response(204) + ) + admin_client.delete_policy_rule(1) # should not raise @pytest.mark.parametrize( ("status_code", "exc_class"), [ @@ -264,6 +397,7 @@ def test_set_pg_creds(admin_client: Client) -> None: (403, MciasForbiddenError), (404, MciasNotFoundError), (409, MciasConflictError), + (429, MciasRateLimitError), (500, MciasServerError), ], ) diff --git a/clients/rust/src/lib.rs b/clients/rust/src/lib.rs index fa4a856..349b884 100644 --- a/clients/rust/src/lib.rs +++ b/clients/rust/src/lib.rs @@ -70,7 +70,7 @@ pub enum MciasError { Decode(String), } -// ---- Data types ---- +// ---- Public data types ---- /// Account information returned by the server. #[derive(Debug, Clone, Deserialize)] @@ -101,6 +101,11 @@ pub struct TokenClaims { pub struct PublicKey { pub kty: String, pub crv: String, + /// Key use — always `"sig"` for the MCIAS signing key. + #[serde(rename = "use")] + pub key_use: Option, + /// Algorithm — always `"EdDSA"`. Validate this before trusting the key. + pub alg: Option, pub x: String, } @@ -114,6 +119,106 @@ pub struct PgCreds { pub password: String, } +/// Audit log entry returned by `GET /v1/audit`. +#[derive(Debug, Clone, Deserialize)] +pub struct AuditEvent { + pub id: i64, + pub event_type: String, + pub event_time: String, + pub ip_address: String, + pub actor_id: Option, + pub target_id: Option, + pub details: Option, +} + +/// Paginated response from `GET /v1/audit`. +#[derive(Debug, Clone, Deserialize)] +pub struct AuditPage { + pub events: Vec, + pub total: i64, + pub limit: i64, + pub offset: i64, +} + +/// Query parameters for `GET /v1/audit`. +#[derive(Debug, Clone, Default)] +pub struct AuditQuery { + pub limit: Option, + pub offset: Option, + pub event_type: Option, + pub actor_id: Option, +} + +/// A single operator-defined policy rule. +#[derive(Debug, Clone, Deserialize)] +pub struct PolicyRule { + pub id: i64, + pub priority: i64, + pub description: String, + pub rule: RuleBody, + pub enabled: bool, + pub not_before: Option, + pub expires_at: Option, + pub created_at: String, + pub updated_at: String, +} + +/// The match conditions and effect of a policy rule. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RuleBody { + pub effect: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub roles: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub account_types: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub subject_uuid: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub actions: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub resource_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub owner_matches_subject: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub service_names: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub required_tags: Option>, +} + +/// Request body for `POST /v1/policy/rules`. +#[derive(Debug, Clone, Serialize)] +pub struct CreatePolicyRuleRequest { + pub description: String, + pub rule: RuleBody, + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub not_before: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, +} + +/// Request body for `PATCH /v1/policy/rules/{id}`. +#[derive(Debug, Clone, Serialize, Default)] +pub struct UpdatePolicyRuleRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub rule: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub not_before: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub clear_not_before: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub clear_expires_at: Option, +} + // ---- Internal request/response types ---- #[derive(Serialize)] @@ -136,6 +241,22 @@ struct ErrorResponse { error: String, } +#[derive(Deserialize)] +struct RolesResponse { + roles: Vec, +} + +#[derive(Deserialize)] +struct TagsResponse { + tags: Vec, +} + +#[derive(Deserialize)] +struct TotpEnrollResponse { + secret: String, + otpauth_uri: String, +} + // ---- Client options ---- /// Configuration options for the MCIAS client. @@ -160,6 +281,7 @@ pub struct Client { base_url: String, http: reqwest::Client, /// Bearer token storage. `Arc>` so clones share the token. + /// Security: the token is never logged or included in error messages. token: Arc>>, } @@ -285,9 +407,9 @@ impl Client { } /// Update an account's status. Allowed values: `"active"`, `"inactive"`. - pub async fn update_account(&self, id: &str, status: &str) -> Result { + pub async fn update_account(&self, id: &str, status: &str) -> Result<(), MciasError> { let body = serde_json::json!({ "status": status }); - self.patch(&format!("/v1/accounts/{id}"), &body).await + self.patch_no_content(&format!("/v1/accounts/{id}"), &body).await } /// Soft-delete an account and revoke all its tokens. @@ -299,13 +421,17 @@ impl Client { /// Get all roles assigned to an account. pub async fn get_roles(&self, account_id: &str) -> Result, MciasError> { - self.get(&format!("/v1/accounts/{account_id}/roles")).await + // Security: spec wraps roles in {"roles": [...]}, unwrap before returning. + let resp: RolesResponse = self.get(&format!("/v1/accounts/{account_id}/roles")).await?; + Ok(resp.roles) } /// Replace the complete role set for an account. pub async fn set_roles(&self, account_id: &str, roles: &[&str]) -> Result<(), MciasError> { let url = format!("/v1/accounts/{account_id}/roles"); - self.put_no_content(&url, roles).await + // Spec requires {"roles": [...]} wrapper. + let body = serde_json::json!({ "roles": roles }); + self.put_no_content(&url, &body).await } // ---- Token management (admin only) ---- @@ -354,10 +480,142 @@ impl Client { .await } + // ---- TOTP enrollment (authenticated) ---- + + /// Begin TOTP enrollment. Returns `(secret, otpauth_uri)`. + /// The secret is shown once; store it in an authenticator app immediately. + pub async fn enroll_totp(&self) -> Result<(String, String), MciasError> { + let resp: TotpEnrollResponse = + self.post("/v1/auth/totp/enroll", &serde_json::json!({})).await?; + Ok((resp.secret, resp.otpauth_uri)) + } + + /// Confirm TOTP enrollment with the current 6-digit code. + /// On success, TOTP becomes required for all future logins. + pub async fn confirm_totp(&self, code: &str) -> Result<(), MciasError> { + let body = serde_json::json!({ "code": code }); + self.post_empty_body("/v1/auth/totp/confirm", &body).await + } + + // ---- Password management ---- + + /// Change the caller's own password (self-service). Requires the current + /// password to guard against token-theft attacks. + pub async fn change_password( + &self, + current_password: &str, + new_password: &str, + ) -> Result<(), MciasError> { + let body = serde_json::json!({ + "current_password": current_password, + "new_password": new_password, + }); + self.put_no_content("/v1/auth/password", &body).await + } + + // ---- Admin: TOTP removal ---- + + /// Remove TOTP enrollment from an account (admin). Use for recovery when + /// a user loses their TOTP device. + pub async fn remove_totp(&self, account_id: &str) -> Result<(), MciasError> { + let body = serde_json::json!({ "account_id": account_id }); + self.delete_with_body("/v1/auth/totp", &body).await + } + + // ---- Admin: password reset ---- + + /// Reset an account's password without requiring the current password. + pub async fn admin_set_password( + &self, + account_id: &str, + new_password: &str, + ) -> Result<(), MciasError> { + let body = serde_json::json!({ "new_password": new_password }); + self.put_no_content(&format!("/v1/accounts/{account_id}/password"), &body) + .await + } + + // ---- Account tags (admin) ---- + + /// Get all tags for an account. + pub async fn get_tags(&self, account_id: &str) -> Result, MciasError> { + let resp: TagsResponse = + self.get(&format!("/v1/accounts/{account_id}/tags")).await?; + Ok(resp.tags) + } + + /// Replace the full tag set for an account atomically. Pass an empty slice + /// to clear all tags. Returns the updated tag list. + pub async fn set_tags( + &self, + account_id: &str, + tags: &[&str], + ) -> Result, MciasError> { + let body = serde_json::json!({ "tags": tags }); + let resp: TagsResponse = + self.put_with_response(&format!("/v1/accounts/{account_id}/tags"), &body).await?; + Ok(resp.tags) + } + + // ---- Audit log (admin) ---- + + /// Query the audit log. Returns a paginated [`AuditPage`]. + pub async fn list_audit(&self, query: AuditQuery) -> Result { + let mut params: Vec<(&str, String)> = Vec::new(); + if let Some(limit) = query.limit { + params.push(("limit", limit.to_string())); + } + if let Some(offset) = query.offset { + params.push(("offset", offset.to_string())); + } + if let Some(ref et) = query.event_type { + params.push(("event_type", et.clone())); + } + if let Some(ref aid) = query.actor_id { + params.push(("actor_id", aid.clone())); + } + self.get_with_query("/v1/audit", ¶ms).await + } + + // ---- Policy rules (admin) ---- + + /// List all operator-defined policy rules ordered by priority. + pub async fn list_policy_rules(&self) -> Result, MciasError> { + self.get("/v1/policy/rules").await + } + + /// Create a new policy rule. + pub async fn create_policy_rule( + &self, + req: CreatePolicyRuleRequest, + ) -> Result { + self.post_expect_status("/v1/policy/rules", &req, StatusCode::CREATED) + .await + } + + /// Get a single policy rule by ID. + pub async fn get_policy_rule(&self, id: i64) -> Result { + self.get(&format!("/v1/policy/rules/{id}")).await + } + + /// Update a policy rule. Omitted fields are left unchanged. + pub async fn update_policy_rule( + &self, + id: i64, + req: UpdatePolicyRuleRequest, + ) -> Result { + self.patch(&format!("/v1/policy/rules/{id}"), &req).await + } + + /// Delete a policy rule permanently. + pub async fn delete_policy_rule(&self, id: i64) -> Result<(), MciasError> { + self.delete(&format!("/v1/policy/rules/{id}")).await + } + // ---- HTTP helpers ---- - /// Build a request with the Authorization header set from the stored token. - /// Security: the token is read under a read-lock and is not logged. + /// Build the Authorization header value from the stored token. + /// Security: the token is read under a read-lock and is never logged. async fn auth_header(&self) -> Option { let guard = self.token.read().await; guard.as_deref().and_then(|tok| { @@ -383,6 +641,22 @@ impl Client { self.expect_success(resp).await } + async fn get_with_query Deserialize<'de>>( + &self, + path: &str, + params: &[(&str, String)], + ) -> Result { + let mut req = self + .http + .get(format!("{}{path}", self.base_url)) + .query(params); + if let Some(auth) = self.auth_header().await { + req = req.header(header::AUTHORIZATION, auth); + } + let resp = req.send().await?; + self.decode(resp).await + } + async fn post Deserialize<'de>>( &self, path: &str, @@ -434,6 +708,19 @@ impl Client { self.expect_success(resp).await } + /// POST with a JSON body that expects a 2xx (no body) response. + async fn post_empty_body(&self, path: &str, body: &B) -> Result<(), MciasError> { + let mut req = self + .http + .post(format!("{}{path}", self.base_url)) + .json(body); + if let Some(auth) = self.auth_header().await { + req = req.header(header::AUTHORIZATION, auth); + } + let resp = req.send().await?; + self.expect_success(resp).await + } + async fn patch Deserialize<'de>>( &self, path: &str, @@ -450,6 +737,18 @@ impl Client { self.decode(resp).await } + async fn patch_no_content(&self, path: &str, body: &B) -> Result<(), MciasError> { + let mut req = self + .http + .patch(format!("{}{path}", self.base_url)) + .json(body); + if let Some(auth) = self.auth_header().await { + req = req.header(header::AUTHORIZATION, auth); + } + let resp = req.send().await?; + self.expect_success(resp).await + } + async fn put_no_content(&self, path: &str, body: &B) -> Result<(), MciasError> { let mut req = self .http @@ -462,6 +761,22 @@ impl Client { self.expect_success(resp).await } + async fn put_with_response Deserialize<'de>>( + &self, + path: &str, + body: &B, + ) -> Result { + let mut req = self + .http + .put(format!("{}{path}", self.base_url)) + .json(body); + if let Some(auth) = self.auth_header().await { + req = req.header(header::AUTHORIZATION, auth); + } + let resp = req.send().await?; + self.decode(resp).await + } + async fn delete(&self, path: &str) -> Result<(), MciasError> { let mut req = self.http.delete(format!("{}{path}", self.base_url)); if let Some(auth) = self.auth_header().await { @@ -471,6 +786,19 @@ impl Client { self.expect_success(resp).await } + /// DELETE with a JSON request body (used by `DELETE /v1/auth/totp`). + async fn delete_with_body(&self, path: &str, body: &B) -> Result<(), MciasError> { + let mut req = self + .http + .delete(format!("{}{path}", self.base_url)) + .json(body); + if let Some(auth) = self.auth_header().await { + req = req.header(header::AUTHORIZATION, auth); + } + let resp = req.send().await?; + self.expect_success(resp).await + } + async fn decode Deserialize<'de>>( &self, resp: reqwest::Response, diff --git a/clients/rust/tests/client_tests.rs b/clients/rust/tests/client_tests.rs index f433d8c..06a8980 100644 --- a/clients/rust/tests/client_tests.rs +++ b/clients/rust/tests/client_tests.rs @@ -1,12 +1,18 @@ -use mcias_client::{Client, ClientOptions, MciasError}; +use mcias_client::{ + AuditQuery, Client, ClientOptions, CreatePolicyRuleRequest, MciasError, RuleBody, + UpdatePolicyRuleRequest, +}; use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; async fn admin_client(server: &MockServer) -> Client { - Client::new(&server.uri(), ClientOptions { - token: Some("admin-token".to_string()), - ..Default::default() - }) + Client::new( + &server.uri(), + ClientOptions { + token: Some("admin-token".to_string()), + ..Default::default() + }, + ) .unwrap() } @@ -48,7 +54,10 @@ async fn test_health_server_error() { let c = Client::new(&server.uri(), ClientOptions::default()).unwrap(); let err = c.health().await.unwrap_err(); - assert!(matches!(err, MciasError::Server { .. }), "expected Server error, got {err:?}"); + assert!( + matches!(err, MciasError::Server { .. }), + "expected Server error, got {err:?}" + ); } // ---- public key ---- @@ -61,6 +70,8 @@ async fn test_get_public_key() { .respond_with(json_body(serde_json::json!({ "kty": "OKP", "crv": "Ed25519", + "use": "sig", + "alg": "EdDSA", "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo" }))) .mount(&server) @@ -70,6 +81,8 @@ async fn test_get_public_key() { let pk = c.get_public_key().await.expect("get_public_key should succeed"); assert_eq!(pk.kty, "OKP"); assert_eq!(pk.crv, "Ed25519"); + assert_eq!(pk.key_use.as_deref(), Some("sig")); + assert_eq!(pk.alg.as_deref(), Some("EdDSA")); } // ---- login ---- @@ -99,7 +112,10 @@ async fn test_login_bad_credentials() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/v1/auth/login")) - .respond_with(json_body_status(401, serde_json::json!({"error": "invalid credentials"}))) + .respond_with(json_body_status( + 401, + serde_json::json!({"error": "invalid credentials"}), + )) .mount(&server) .await; @@ -119,10 +135,13 @@ async fn test_logout_clears_token() { .mount(&server) .await; - let c = Client::new(&server.uri(), ClientOptions { - token: Some("existing-token".to_string()), - ..Default::default() - }) + let c = Client::new( + &server.uri(), + ClientOptions { + token: Some("existing-token".to_string()), + ..Default::default() + }, + ) .unwrap(); c.logout().await.unwrap(); assert!(c.token().await.is_none(), "token should be cleared after logout"); @@ -142,10 +161,13 @@ async fn test_renew_token() { .mount(&server) .await; - let c = Client::new(&server.uri(), ClientOptions { - token: Some("old-token".to_string()), - ..Default::default() - }) + let c = Client::new( + &server.uri(), + ClientOptions { + token: Some("old-token".to_string()), + ..Default::default() + }, + ) .unwrap(); let (tok, _) = c.renew_token().await.unwrap(); assert_eq!(tok, "new-token"); @@ -224,7 +246,10 @@ async fn test_create_account_conflict() { let server = MockServer::start().await; Mock::given(method("POST")) .and(path("/v1/accounts")) - .respond_with(json_body_status(409, serde_json::json!({"error": "username already exists"}))) + .respond_with(json_body_status( + 409, + serde_json::json!({"error": "username already exists"}), + )) .mount(&server) .await; @@ -259,7 +284,10 @@ async fn test_get_account_not_found() { let server = MockServer::start().await; Mock::given(method("GET")) .and(path("/v1/accounts/missing")) - .respond_with(json_body_status(404, serde_json::json!({"error": "account not found"}))) + .respond_with(json_body_status( + 404, + serde_json::json!({"error": "account not found"}), + )) .mount(&server) .await; @@ -271,19 +299,15 @@ async fn test_get_account_not_found() { #[tokio::test] async fn test_update_account() { let server = MockServer::start().await; + // PATCH /v1/accounts/{id} returns 204 No Content per spec. Mock::given(method("PATCH")) .and(path("/v1/accounts/uuid-1")) - .respond_with(json_body(serde_json::json!({ - "id": "uuid-1", "username": "alice", "account_type": "human", - "status": "inactive", "created_at": "2023-11-15T12:00:00Z", - "updated_at": "2023-11-15T13:00:00Z", "totp_enabled": false - }))) + .respond_with(ResponseTemplate::new(204)) .mount(&server) .await; let c = admin_client(&server).await; - let a = c.update_account("uuid-1", "inactive").await.unwrap(); - assert_eq!(a.status, "inactive"); + c.update_account("uuid-1", "inactive").await.unwrap(); } #[tokio::test] @@ -305,12 +329,14 @@ async fn test_delete_account() { async fn test_get_set_roles() { let server = MockServer::start().await; + // Spec wraps the array: {"roles": [...]} Mock::given(method("GET")) .and(path("/v1/accounts/uuid-1/roles")) - .respond_with(json_body(serde_json::json!(["admin", "viewer"]))) + .respond_with(json_body(serde_json::json!({"roles": ["admin", "viewer"]}))) .mount(&server) .await; + // Spec requires {"roles": [...]} in the PUT body. Mock::given(method("PUT")) .and(path("/v1/accounts/uuid-1/roles")) .respond_with(ResponseTemplate::new(204)) @@ -363,7 +389,10 @@ async fn test_pg_creds_not_found() { let server = MockServer::start().await; Mock::given(method("GET")) .and(path("/v1/accounts/uuid-1/pgcreds")) - .respond_with(json_body_status(404, serde_json::json!({"error": "no pg credentials found"}))) + .respond_with(json_body_status( + 404, + serde_json::json!({"error": "no pg credentials found"}), + )) .mount(&server) .await; @@ -405,6 +434,298 @@ async fn test_set_get_pg_creds() { assert_eq!(creds.password, "dbpass"); } +// ---- TOTP ---- + +#[tokio::test] +async fn test_enroll_totp() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/auth/totp/enroll")) + .respond_with(json_body(serde_json::json!({ + "secret": "JBSWY3DPEHPK3PXP", + "otpauth_uri": "otpauth://totp/MCIAS:alice?secret=JBSWY3DPEHPK3PXP&issuer=MCIAS" + }))) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let (secret, uri) = c.enroll_totp().await.unwrap(); + assert_eq!(secret, "JBSWY3DPEHPK3PXP"); + assert!(uri.starts_with("otpauth://totp/")); +} + +#[tokio::test] +async fn test_confirm_totp() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/auth/totp/confirm")) + .respond_with(ResponseTemplate::new(204)) + .mount(&server) + .await; + + let c = admin_client(&server).await; + c.confirm_totp("123456").await.unwrap(); +} + +#[tokio::test] +async fn test_remove_totp() { + let server = MockServer::start().await; + Mock::given(method("DELETE")) + .and(path("/v1/auth/totp")) + .respond_with(ResponseTemplate::new(204)) + .mount(&server) + .await; + + let c = admin_client(&server).await; + c.remove_totp("some-account-uuid").await.unwrap(); +} + +// ---- password management ---- + +#[tokio::test] +async fn test_change_password() { + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/v1/auth/password")) + .respond_with(ResponseTemplate::new(204)) + .mount(&server) + .await; + + let c = admin_client(&server).await; + c.change_password("old-pass", "new-pass-long-enough").await.unwrap(); +} + +#[tokio::test] +async fn test_change_password_wrong_current() { + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/v1/auth/password")) + .respond_with(json_body_status( + 401, + serde_json::json!({"error": "current password is incorrect", "code": "unauthorized"}), + )) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let err = c + .change_password("wrong", "new-pass-long-enough") + .await + .unwrap_err(); + assert!(matches!(err, MciasError::Auth(_))); +} + +#[tokio::test] +async fn test_admin_set_password() { + let server = MockServer::start().await; + Mock::given(method("PUT")) + .and(path("/v1/accounts/uuid-1/password")) + .respond_with(ResponseTemplate::new(204)) + .mount(&server) + .await; + + let c = admin_client(&server).await; + c.admin_set_password("uuid-1", "new-pass-long-enough").await.unwrap(); +} + +// ---- tags ---- + +#[tokio::test] +async fn test_get_set_tags() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/v1/accounts/uuid-1/tags")) + .respond_with(json_body( + serde_json::json!({"tags": ["env:production", "svc:payments-api"]}), + )) + .mount(&server) + .await; + + Mock::given(method("PUT")) + .and(path("/v1/accounts/uuid-1/tags")) + .respond_with(json_body(serde_json::json!({"tags": ["env:staging"]}))) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let tags = c.get_tags("uuid-1").await.unwrap(); + assert_eq!(tags, vec!["env:production", "svc:payments-api"]); + + let updated = c.set_tags("uuid-1", &["env:staging"]).await.unwrap(); + assert_eq!(updated, vec!["env:staging"]); +} + +// ---- audit log ---- + +#[tokio::test] +async fn test_list_audit() { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/audit")) + .respond_with(json_body(serde_json::json!({ + "events": [ + { + "id": 1, + "event_type": "login_ok", + "event_time": "2026-03-11T09:01:23Z", + "ip_address": "192.0.2.1", + "actor_id": "uuid-1", + "target_id": null, + "details": null + } + ], + "total": 1, + "limit": 50, + "offset": 0 + }))) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let page = c.list_audit(AuditQuery::default()).await.unwrap(); + assert_eq!(page.total, 1); + assert_eq!(page.events.len(), 1); + assert_eq!(page.events[0].event_type, "login_ok"); +} + +// ---- policy rules ---- + +#[tokio::test] +async fn test_list_policy_rules() { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/policy/rules")) + .respond_with(json_body(serde_json::json!([]))) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let rules = c.list_policy_rules().await.unwrap(); + assert!(rules.is_empty()); +} + +#[tokio::test] +async fn test_create_policy_rule() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/policy/rules")) + .respond_with( + ResponseTemplate::new(201) + .set_body_json(serde_json::json!({ + "id": 1, + "priority": 100, + "description": "Allow payments-api to read its own pgcreds", + "rule": {"effect": "allow", "roles": ["svc:payments-api"]}, + "enabled": true, + "not_before": null, + "expires_at": null, + "created_at": "2026-03-11T09:00:00Z", + "updated_at": "2026-03-11T09:00:00Z" + })) + .insert_header("content-type", "application/json"), + ) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let rule = c + .create_policy_rule(CreatePolicyRuleRequest { + description: "Allow payments-api to read its own pgcreds".to_string(), + rule: RuleBody { + effect: "allow".to_string(), + roles: Some(vec!["svc:payments-api".to_string()]), + account_types: None, + subject_uuid: None, + actions: None, + resource_type: None, + owner_matches_subject: None, + service_names: None, + required_tags: None, + }, + priority: None, + not_before: None, + expires_at: None, + }) + .await + .unwrap(); + assert_eq!(rule.id, 1); + assert_eq!(rule.description, "Allow payments-api to read its own pgcreds"); +} + +#[tokio::test] +async fn test_get_policy_rule() { + let server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/policy/rules/1")) + .respond_with(json_body(serde_json::json!({ + "id": 1, + "priority": 100, + "description": "test rule", + "rule": {"effect": "deny"}, + "enabled": true, + "not_before": null, + "expires_at": null, + "created_at": "2026-03-11T09:00:00Z", + "updated_at": "2026-03-11T09:00:00Z" + }))) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let rule = c.get_policy_rule(1).await.unwrap(); + assert_eq!(rule.id, 1); + assert_eq!(rule.rule.effect, "deny"); +} + +#[tokio::test] +async fn test_update_policy_rule() { + let server = MockServer::start().await; + Mock::given(method("PATCH")) + .and(path("/v1/policy/rules/1")) + .respond_with(json_body(serde_json::json!({ + "id": 1, + "priority": 75, + "description": "updated rule", + "rule": {"effect": "allow"}, + "enabled": false, + "not_before": null, + "expires_at": null, + "created_at": "2026-03-11T09:00:00Z", + "updated_at": "2026-03-11T10:00:00Z" + }))) + .mount(&server) + .await; + + let c = admin_client(&server).await; + let rule = c + .update_policy_rule( + 1, + UpdatePolicyRuleRequest { + enabled: Some(false), + priority: Some(75), + ..Default::default() + }, + ) + .await + .unwrap(); + assert!(!rule.enabled); + assert_eq!(rule.priority, 75); +} + +#[tokio::test] +async fn test_delete_policy_rule() { + let server = MockServer::start().await; + Mock::given(method("DELETE")) + .and(path("/v1/policy/rules/1")) + .respond_with(ResponseTemplate::new(204)) + .mount(&server) + .await; + + let c = admin_client(&server).await; + c.delete_policy_rule(1).await.unwrap(); +} + // ---- error type coverage ---- #[tokio::test] @@ -416,11 +737,13 @@ async fn test_forbidden_error() { .mount(&server) .await; - // Use a non-admin token. - let c = Client::new(&server.uri(), ClientOptions { - token: Some("user-token".to_string()), - ..Default::default() - }) + let c = Client::new( + &server.uri(), + ClientOptions { + token: Some("user-token".to_string()), + ..Default::default() + }, + ) .unwrap(); let err = c.list_accounts().await.unwrap_err(); assert!(matches!(err, MciasError::Forbidden(_)));