clients: expand Go, Python, Rust client APIs

- Add TOTP enrollment/confirmation/removal to all clients
- Add password change and admin set-password endpoints
- Add account listing, status update, and tag management
- Add audit log listing with filter support
- Add policy rule CRUD operations
- Expand test coverage for all new endpoints across clients
- Fix .gitignore to exclude built binaries

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-12 20:29:11 -07:00
parent ec7c966ad2
commit cbcb1a0533
11 changed files with 1938 additions and 255 deletions

View File

@@ -4,5 +4,5 @@
2. Run `go test ./...` abort if failures 2. Run `go test ./...` abort if failures
3. Run `go vet ./...` 3. Run `go vet ./...`
4. Run `git add -A && git status` show user what will be committed 4. Run `git add -A && git status` show user what will be committed
5. Ask user for commit message 5. Generate an appropriate commit message based on your instructions.
6. Run `git commit -m "<message>"` and verify with `git log -1` 6. Run `git commit -m "<message>"` and verify with `git log -1`

7
.gitignore vendored
View File

@@ -34,5 +34,10 @@ clients/python/*.egg-info/
clients/lisp/**/*.fasl clients/lisp/**/*.fasl
# manual testing # manual testing
/run/ run/
.env .env
/cmd/mciasctl/mciasctl
/cmd/mciasdb/mciasdb
/cmd/mciasgrpcctl/mciasgrpcctl
/cmd/mciassrv/mciassrv

View File

@@ -3,6 +3,7 @@
// Security: bearer tokens are stored under a sync.RWMutex and are never written // Security: bearer tokens are stored under a sync.RWMutex and are never written
// to logs or included in error messages anywhere in this package. // to logs or included in error messages anywhere in this package.
package mciasgoclient package mciasgoclient
import ( import (
"bytes" "bytes"
"crypto/tls" "crypto/tls"
@@ -15,32 +16,43 @@ import (
"strings" "strings"
"sync" "sync"
) )
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Error types // Error types
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// MciasError is the base error type for all MCIAS client errors. // MciasError is the base error type for all MCIAS client errors.
type MciasError struct { type MciasError struct {
StatusCode int StatusCode int
Message string Message string
} }
func (e *MciasError) Error() string { func (e *MciasError) Error() string {
return fmt.Sprintf("mciasgoclient: HTTP %d: %s", e.StatusCode, e.Message) return fmt.Sprintf("mciasgoclient: HTTP %d: %s", e.StatusCode, e.Message)
} }
// MciasAuthError is returned for 401 Unauthorized responses. // MciasAuthError is returned for 401 Unauthorized responses.
type MciasAuthError struct{ MciasError } type MciasAuthError struct{ MciasError }
// MciasForbiddenError is returned for 403 Forbidden responses. // MciasForbiddenError is returned for 403 Forbidden responses.
type MciasForbiddenError struct{ MciasError } type MciasForbiddenError struct{ MciasError }
// MciasNotFoundError is returned for 404 Not Found responses. // MciasNotFoundError is returned for 404 Not Found responses.
type MciasNotFoundError struct{ MciasError } type MciasNotFoundError struct{ MciasError }
// MciasInputError is returned for 400 Bad Request responses. // MciasInputError is returned for 400 Bad Request responses.
type MciasInputError struct{ MciasError } type MciasInputError struct{ MciasError }
// MciasConflictError is returned for 409 Conflict responses. // MciasConflictError is returned for 409 Conflict responses.
type MciasConflictError struct{ MciasError } type MciasConflictError struct{ MciasError }
// MciasServerError is returned for 5xx responses. // MciasServerError is returned for 5xx responses.
type MciasServerError struct{ MciasError } type MciasServerError struct{ MciasError }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Data types // Data types
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Account represents a user or service account. // Account represents a user or service account.
type Account struct { type Account struct {
ID string `json:"id"` ID string `json:"id"`
@@ -51,6 +63,7 @@ type Account struct {
UpdatedAt string `json:"updated_at"` UpdatedAt string `json:"updated_at"`
TOTPEnabled bool `json:"totp_enabled"` TOTPEnabled bool `json:"totp_enabled"`
} }
// PublicKey represents the server's Ed25519 public key in JWK format. // PublicKey represents the server's Ed25519 public key in JWK format.
type PublicKey struct { type PublicKey struct {
Kty string `json:"kty"` Kty string `json:"kty"`
@@ -59,6 +72,7 @@ type PublicKey struct {
Use string `json:"use,omitempty"` Use string `json:"use,omitempty"`
Alg string `json:"alg,omitempty"` Alg string `json:"alg,omitempty"`
} }
// TokenClaims is returned by ValidateToken. // TokenClaims is returned by ValidateToken.
type TokenClaims struct { type TokenClaims struct {
Valid bool `json:"valid"` Valid bool `json:"valid"`
@@ -66,6 +80,7 @@ type TokenClaims struct {
Roles []string `json:"roles,omitempty"` Roles []string `json:"roles,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` ExpiresAt string `json:"expires_at,omitempty"`
} }
// PGCreds holds Postgres connection credentials. // PGCreds holds Postgres connection credentials.
type PGCreds struct { type PGCreds struct {
Host string `json:"host"` Host string `json:"host"`
@@ -74,9 +89,94 @@ type PGCreds struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
} }
// TOTPEnrollResponse is returned by EnrollTOTP.
type TOTPEnrollResponse struct {
Secret string `json:"secret"`
OTPAuthURI string `json:"otpauth_uri"`
}
// AuditEvent is a single entry in the audit log.
type AuditEvent struct {
ID int `json:"id"`
EventType string `json:"event_type"`
EventTime string `json:"event_time"`
ActorID string `json:"actor_id,omitempty"`
TargetID string `json:"target_id,omitempty"`
IPAddress string `json:"ip_address"`
Details string `json:"details,omitempty"`
}
// AuditListResponse is returned by ListAudit.
type AuditListResponse struct {
Events []AuditEvent `json:"events"`
Total int `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
// AuditFilter holds optional filter parameters for ListAudit.
type AuditFilter struct {
Limit int
Offset int
EventType string
ActorID string
}
// PolicyRuleBody holds the match conditions and effect of a policy rule.
// All fields except Effect are optional; an omitted field acts as a wildcard.
type PolicyRuleBody struct {
Effect string `json:"effect"`
Roles []string `json:"roles,omitempty"`
AccountTypes []string `json:"account_types,omitempty"`
SubjectUUID string `json:"subject_uuid,omitempty"`
Actions []string `json:"actions,omitempty"`
ResourceType string `json:"resource_type,omitempty"`
OwnerMatchesSubject bool `json:"owner_matches_subject,omitempty"`
ServiceNames []string `json:"service_names,omitempty"`
RequiredTags []string `json:"required_tags,omitempty"`
}
// PolicyRule is a complete operator-defined policy rule as returned by the API.
type PolicyRule struct {
ID int `json:"id"`
Priority int `json:"priority"`
Description string `json:"description"`
Rule PolicyRuleBody `json:"rule"`
Enabled bool `json:"enabled"`
NotBefore string `json:"not_before,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
// CreatePolicyRuleRequest holds the parameters for creating a policy rule.
type CreatePolicyRuleRequest struct {
Description string `json:"description"`
Priority int `json:"priority,omitempty"`
Rule PolicyRuleBody `json:"rule"`
NotBefore string `json:"not_before,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
// UpdatePolicyRuleRequest holds the parameters for updating a policy rule.
// All fields are optional; omitted fields are left unchanged.
// Set ClearNotBefore or ClearExpiresAt to true to remove those constraints.
type UpdatePolicyRuleRequest struct {
Description string `json:"description,omitempty"`
Priority *int `json:"priority,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
Rule *PolicyRuleBody `json:"rule,omitempty"`
NotBefore string `json:"not_before,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
ClearNotBefore bool `json:"clear_not_before,omitempty"`
ClearExpiresAt bool `json:"clear_expires_at,omitempty"`
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Options and Client struct // Options and Client struct
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Options configures the MCIAS client. // Options configures the MCIAS client.
type Options struct { type Options struct {
// CACertPath is an optional path to a PEM-encoded CA certificate for TLS // CACertPath is an optional path to a PEM-encoded CA certificate for TLS
@@ -85,6 +185,7 @@ type Options struct {
// Token is an optional pre-existing bearer token. // Token is an optional pre-existing bearer token.
Token string Token string
} }
// Client is a thread-safe MCIAS REST API client. // Client is a thread-safe MCIAS REST API client.
// Security: the bearer token is guarded by a sync.RWMutex; it is never // Security: the bearer token is guarded by a sync.RWMutex; it is never
// written to logs or included in error messages in this library. // written to logs or included in error messages in this library.
@@ -94,9 +195,11 @@ type Client struct {
mu sync.RWMutex mu sync.RWMutex
token string token string
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Constructor // Constructor
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// New creates a new Client for the given serverURL. // New creates a new Client for the given serverURL.
// TLS 1.2 is the minimum version enforced on all connections. // TLS 1.2 is the minimum version enforced on all connections.
// If opts.CACertPath is set, that CA certificate is added to the trust pool. // If opts.CACertPath is set, that CA certificate is added to the trust pool.
@@ -126,20 +229,24 @@ func New(serverURL string, opts Options) (*Client, error) {
} }
return c, nil return c, nil
} }
// Token returns the current bearer token (empty string if not logged in). // Token returns the current bearer token (empty string if not logged in).
func (c *Client) Token() string { func (c *Client) Token() string {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
return c.token return c.token
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Internal helpers // Internal helpers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func (c *Client) setToken(tok string) { func (c *Client) setToken(tok string) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.token = tok c.token = tok
} }
func (c *Client) do(method, path string, body interface{}, out interface{}) error { func (c *Client) do(method, path string, body interface{}, out interface{}) error {
var reqBody io.Reader var reqBody io.Reader
if body != nil { if body != nil {
@@ -195,6 +302,7 @@ func (c *Client) do(method, path string, body interface{}, out interface{}) erro
} }
return nil return nil
} }
func makeError(status int, msg string) error { func makeError(status int, msg string) error {
base := MciasError{StatusCode: status, Message: msg} base := MciasError{StatusCode: status, Message: msg}
switch { switch {
@@ -212,13 +320,16 @@ func makeError(status int, msg string) error {
return &MciasServerError{base} return &MciasServerError{base}
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// API methods // API methods — Public
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Health calls GET /v1/health. Returns nil if the server is healthy. // Health calls GET /v1/health. Returns nil if the server is healthy.
func (c *Client) Health() error { func (c *Client) Health() error {
return c.do(http.MethodGet, "/v1/health", nil, nil) return c.do(http.MethodGet, "/v1/health", nil, nil)
} }
// GetPublicKey returns the server's Ed25519 public key in JWK format. // GetPublicKey returns the server's Ed25519 public key in JWK format.
func (c *Client) GetPublicKey() (*PublicKey, error) { func (c *Client) GetPublicKey() (*PublicKey, error) {
var pk PublicKey var pk PublicKey
@@ -227,6 +338,7 @@ func (c *Client) GetPublicKey() (*PublicKey, error) {
} }
return &pk, nil return &pk, nil
} }
// Login authenticates with username and password. On success the token is // Login authenticates with username and password. On success the token is
// stored in the Client and returned along with the expiry timestamp. // stored in the Client and returned along with the expiry timestamp.
// totpCode may be empty for accounts without TOTP. // totpCode may be empty for accounts without TOTP.
@@ -245,6 +357,23 @@ func (c *Client) Login(username, password, totpCode string) (token, expiresAt st
c.setToken(resp.Token) c.setToken(resp.Token)
return resp.Token, resp.ExpiresAt, nil return resp.Token, resp.ExpiresAt, nil
} }
// ValidateToken validates a token string against the server.
// Returns claims; Valid is false (not an error) if the token is expired or
// revoked.
func (c *Client) ValidateToken(token string) (*TokenClaims, error) {
var claims TokenClaims
if err := c.do(http.MethodPost, "/v1/token/validate",
map[string]string{"token": token}, &claims); err != nil {
return nil, err
}
return &claims, nil
}
// ---------------------------------------------------------------------------
// API methods — Authenticated
// ---------------------------------------------------------------------------
// Logout revokes the current token on the server and clears it from the client. // Logout revokes the current token on the server and clears it from the client.
func (c *Client) Logout() error { func (c *Client) Logout() error {
if err := c.do(http.MethodPost, "/v1/auth/logout", nil, nil); err != nil { if err := c.do(http.MethodPost, "/v1/auth/logout", nil, nil); err != nil {
@@ -253,6 +382,7 @@ func (c *Client) Logout() error {
c.setToken("") c.setToken("")
return nil return nil
} }
// RenewToken exchanges the current token for a fresh one. // RenewToken exchanges the current token for a fresh one.
// The new token is stored in the client and returned. // The new token is stored in the client and returned.
func (c *Client) RenewToken() (token, expiresAt string, err error) { func (c *Client) RenewToken() (token, expiresAt string, err error) {
@@ -266,17 +396,63 @@ func (c *Client) RenewToken() (token, expiresAt string, err error) {
c.setToken(resp.Token) c.setToken(resp.Token)
return resp.Token, resp.ExpiresAt, nil return resp.Token, resp.ExpiresAt, nil
} }
// ValidateToken validates a token string against the server.
// Returns claims; Valid is false (not an error) if the token is expired or // EnrollTOTP begins TOTP enrollment for the authenticated account.
// revoked. // Returns a base32 secret and an otpauth:// URI for QR-code generation.
func (c *Client) ValidateToken(token string) (*TokenClaims, error) { // The secret is shown once; it is not retrievable after this call.
var claims TokenClaims // TOTP is not enforced until confirmed via ConfirmTOTP.
if err := c.do(http.MethodPost, "/v1/token/validate", func (c *Client) EnrollTOTP() (*TOTPEnrollResponse, error) {
map[string]string{"token": token}, &claims); err != nil { var resp TOTPEnrollResponse
if err := c.do(http.MethodPost, "/v1/auth/totp/enroll", nil, &resp); err != nil {
return nil, err return nil, err
} }
return &claims, nil return &resp, nil
} }
// ConfirmTOTP completes TOTP enrollment by verifying the current code against
// the pending secret. On success, TOTP becomes required for all future logins.
func (c *Client) ConfirmTOTP(code string) error {
return c.do(http.MethodPost, "/v1/auth/totp/confirm",
map[string]string{"code": code}, nil)
}
// ChangePassword changes the password of the currently authenticated human
// account. currentPassword is required to prevent token-theft attacks.
// On success, all active sessions except the caller's are revoked.
//
// Security: both passwords are transmitted over TLS only; the server verifies
// currentPassword with constant-time comparison before accepting the change.
func (c *Client) ChangePassword(currentPassword, newPassword string) error {
return c.do(http.MethodPut, "/v1/auth/password", map[string]string{
"current_password": currentPassword,
"new_password": newPassword,
}, nil)
}
// ---------------------------------------------------------------------------
// API methods — Admin: Auth
// ---------------------------------------------------------------------------
// RemoveTOTP clears TOTP enrollment for the given account (admin).
// Use for account recovery when a user has lost their TOTP device.
func (c *Client) RemoveTOTP(accountID string) error {
return c.do(http.MethodDelete, "/v1/auth/totp",
map[string]string{"account_id": accountID}, nil)
}
// ---------------------------------------------------------------------------
// API methods — Admin: Accounts
// ---------------------------------------------------------------------------
// ListAccounts returns all accounts. Requires admin role.
func (c *Client) ListAccounts() ([]Account, error) {
var accounts []Account
if err := c.do(http.MethodGet, "/v1/accounts", nil, &accounts); err != nil {
return nil, err
}
return accounts, nil
}
// CreateAccount creates a new account. accountType is "human" or "system". // CreateAccount creates a new account. accountType is "human" or "system".
// password is required for human accounts. // password is required for human accounts.
func (c *Client) CreateAccount(username, accountType, password string) (*Account, error) { func (c *Client) CreateAccount(username, accountType, password string) (*Account, error) {
@@ -293,14 +469,7 @@ func (c *Client) CreateAccount(username, accountType, password string) (*Account
} }
return &acct, nil return &acct, nil
} }
// ListAccounts returns all accounts. Requires admin role.
func (c *Client) ListAccounts() ([]Account, error) {
var accounts []Account
if err := c.do(http.MethodGet, "/v1/accounts", nil, &accounts); err != nil {
return nil, err
}
return accounts, nil
}
// GetAccount returns the account with the given ID. Requires admin role. // GetAccount returns the account with the given ID. Requires admin role.
func (c *Client) GetAccount(id string) (*Account, error) { func (c *Client) GetAccount(id string) (*Account, error) {
var acct Account var acct Account
@@ -309,23 +478,22 @@ func (c *Client) GetAccount(id string) (*Account, error) {
} }
return &acct, nil return &acct, nil
} }
// UpdateAccount updates mutable account fields. Requires admin role.
// Pass an empty string for fields that should not be changed. // UpdateAccount updates mutable account fields (currently only status).
func (c *Client) UpdateAccount(id, status string) (*Account, error) { // Requires admin role. Returns nil on success (HTTP 204).
func (c *Client) UpdateAccount(id, status string) error {
req := map[string]string{} req := map[string]string{}
if status != "" { if status != "" {
req["status"] = status req["status"] = status
} }
var acct Account return c.do(http.MethodPatch, "/v1/accounts/"+id, req, nil)
if err := c.do(http.MethodPatch, "/v1/accounts/"+id, req, &acct); err != nil {
return nil, err
}
return &acct, nil
} }
// DeleteAccount soft-deletes the account with the given ID. Requires admin. // DeleteAccount soft-deletes the account with the given ID. Requires admin.
func (c *Client) DeleteAccount(id string) error { func (c *Client) DeleteAccount(id string) error {
return c.do(http.MethodDelete, "/v1/accounts/"+id, nil, nil) return c.do(http.MethodDelete, "/v1/accounts/"+id, nil, nil)
} }
// GetRoles returns the roles for accountID. Requires admin. // GetRoles returns the roles for accountID. Requires admin.
func (c *Client) GetRoles(accountID string) ([]string, error) { func (c *Client) GetRoles(accountID string) ([]string, error) {
var resp struct { var resp struct {
@@ -336,11 +504,49 @@ func (c *Client) GetRoles(accountID string) ([]string, error) {
} }
return resp.Roles, nil return resp.Roles, nil
} }
// SetRoles replaces the role set for accountID. Requires admin. // SetRoles replaces the role set for accountID. Requires admin.
func (c *Client) SetRoles(accountID string, roles []string) error { func (c *Client) SetRoles(accountID string, roles []string) error {
return c.do(http.MethodPut, "/v1/accounts/"+accountID+"/roles", return c.do(http.MethodPut, "/v1/accounts/"+accountID+"/roles",
map[string][]string{"roles": roles}, nil) map[string][]string{"roles": roles}, nil)
} }
// AdminSetPassword resets a human account's password without requiring the
// current password. Requires admin. All active sessions for the target account
// are revoked on success.
func (c *Client) AdminSetPassword(accountID, newPassword string) error {
return c.do(http.MethodPut, "/v1/accounts/"+accountID+"/password",
map[string]string{"new_password": newPassword}, nil)
}
// GetAccountTags returns the current tag set for an account. Requires admin.
func (c *Client) GetAccountTags(accountID string) ([]string, error) {
var resp struct {
Tags []string `json:"tags"`
}
if err := c.do(http.MethodGet, "/v1/accounts/"+accountID+"/tags", nil, &resp); err != nil {
return nil, err
}
return resp.Tags, nil
}
// SetAccountTags replaces the full tag set for an account atomically.
// Pass an empty slice to clear all tags. Requires admin.
func (c *Client) SetAccountTags(accountID string, tags []string) ([]string, error) {
var resp struct {
Tags []string `json:"tags"`
}
if err := c.do(http.MethodPut, "/v1/accounts/"+accountID+"/tags",
map[string][]string{"tags": tags}, &resp); err != nil {
return nil, err
}
return resp.Tags, nil
}
// ---------------------------------------------------------------------------
// API methods — Admin: Tokens
// ---------------------------------------------------------------------------
// IssueServiceToken issues a long-lived token for a system account. Requires admin. // IssueServiceToken issues a long-lived token for a system account. Requires admin.
func (c *Client) IssueServiceToken(accountID string) (token, expiresAt string, err error) { func (c *Client) IssueServiceToken(accountID string) (token, expiresAt string, err error) {
var resp struct { var resp struct {
@@ -353,10 +559,16 @@ func (c *Client) IssueServiceToken(accountID string) (token, expiresAt string, e
} }
return resp.Token, resp.ExpiresAt, nil return resp.Token, resp.ExpiresAt, nil
} }
// RevokeToken revokes a token by JTI. Requires admin. // RevokeToken revokes a token by JTI. Requires admin.
func (c *Client) RevokeToken(jti string) error { func (c *Client) RevokeToken(jti string) error {
return c.do(http.MethodDelete, "/v1/token/"+jti, nil, nil) return c.do(http.MethodDelete, "/v1/token/"+jti, nil, nil)
} }
// ---------------------------------------------------------------------------
// API methods — Admin: Credentials
// ---------------------------------------------------------------------------
// GetPGCreds returns Postgres credentials for accountID. Requires admin. // GetPGCreds returns Postgres credentials for accountID. Requires admin.
func (c *Client) GetPGCreds(accountID string) (*PGCreds, error) { func (c *Client) GetPGCreds(accountID string) (*PGCreds, error) {
var creds PGCreds var creds PGCreds
@@ -365,6 +577,7 @@ func (c *Client) GetPGCreds(accountID string) (*PGCreds, error) {
} }
return &creds, nil return &creds, nil
} }
// SetPGCreds stores Postgres credentials for accountID. Requires admin. // SetPGCreds stores Postgres credentials for accountID. Requires admin.
// The password is sent over TLS and encrypted at rest server-side. // The password is sent over TLS and encrypted at rest server-side.
func (c *Client) SetPGCreds(accountID, host string, port int, database, username, password string) error { func (c *Client) SetPGCreds(accountID, host string, port int, database, username, password string) error {
@@ -376,3 +589,78 @@ func (c *Client) SetPGCreds(accountID, host string, port int, database, username
"password": password, "password": password,
}, nil) }, nil)
} }
// ---------------------------------------------------------------------------
// API methods — Admin: Audit
// ---------------------------------------------------------------------------
// ListAudit retrieves audit log entries, newest first. Requires admin.
// f may be zero-valued to use defaults (limit=50, offset=0, no filter).
func (c *Client) ListAudit(f AuditFilter) (*AuditListResponse, error) {
path := "/v1/audit?"
if f.Limit > 0 {
path += fmt.Sprintf("limit=%d&", f.Limit)
}
if f.Offset > 0 {
path += fmt.Sprintf("offset=%d&", f.Offset)
}
if f.EventType != "" {
path += fmt.Sprintf("event_type=%s&", f.EventType)
}
if f.ActorID != "" {
path += fmt.Sprintf("actor_id=%s&", f.ActorID)
}
path = strings.TrimRight(path, "&?")
var resp AuditListResponse
if err := c.do(http.MethodGet, path, nil, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// ---------------------------------------------------------------------------
// API methods — Admin: Policy
// ---------------------------------------------------------------------------
// ListPolicyRules returns all operator-defined policy rules ordered by
// priority (ascending). Requires admin.
func (c *Client) ListPolicyRules() ([]PolicyRule, error) {
var rules []PolicyRule
if err := c.do(http.MethodGet, "/v1/policy/rules", nil, &rules); err != nil {
return nil, err
}
return rules, nil
}
// CreatePolicyRule creates a new policy rule. Requires admin.
func (c *Client) CreatePolicyRule(req CreatePolicyRuleRequest) (*PolicyRule, error) {
var rule PolicyRule
if err := c.do(http.MethodPost, "/v1/policy/rules", req, &rule); err != nil {
return nil, err
}
return &rule, nil
}
// GetPolicyRule returns a single policy rule by integer ID. Requires admin.
func (c *Client) GetPolicyRule(id int) (*PolicyRule, error) {
var rule PolicyRule
if err := c.do(http.MethodGet, fmt.Sprintf("/v1/policy/rules/%d", id), nil, &rule); err != nil {
return nil, err
}
return &rule, nil
}
// UpdatePolicyRule updates one or more fields of an existing policy rule.
// Requires admin.
func (c *Client) UpdatePolicyRule(id int, req UpdatePolicyRuleRequest) (*PolicyRule, error) {
var rule PolicyRule
if err := c.do(http.MethodPatch, fmt.Sprintf("/v1/policy/rules/%d", id), req, &rule); err != nil {
return nil, err
}
return &rule, nil
}
// DeletePolicyRule permanently deletes a policy rule. Requires admin.
func (c *Client) DeletePolicyRule(id int) error {
return c.do(http.MethodDelete, fmt.Sprintf("/v1/policy/rules/%d", id), nil, nil)
}

View File

@@ -2,6 +2,7 @@
// All tests use inline httptest.NewServer mocks to keep this module // All tests use inline httptest.NewServer mocks to keep this module
// self-contained (no cross-module imports). // self-contained (no cross-module imports).
package mciasgoclient_test package mciasgoclient_test
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
@@ -9,12 +10,14 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
mciasgoclient "git.wntrmute.dev/kyle/mcias/clients/go" mciasgoclient "git.wntrmute.dev/kyle/mcias/clients/go"
) )
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// helpers // helpers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// newTestClient creates a client pointed at the given test server URL.
func newTestClient(t *testing.T, serverURL string) *mciasgoclient.Client { func newTestClient(t *testing.T, serverURL string) *mciasgoclient.Client {
t.Helper() t.Helper()
c, err := mciasgoclient.New(serverURL, mciasgoclient.Options{}) c, err := mciasgoclient.New(serverURL, mciasgoclient.Options{})
@@ -23,19 +26,21 @@ func newTestClient(t *testing.T, serverURL string) *mciasgoclient.Client {
} }
return c return c
} }
// writeJSON is a shorthand for writing a JSON response.
func writeJSON(w http.ResponseWriter, status int, v interface{}) { func writeJSON(w http.ResponseWriter, status int, v interface{}) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status) w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(v) _ = json.NewEncoder(w).Encode(v)
} }
// writeError writes a JSON error body with the given status code.
func writeError(w http.ResponseWriter, status int, msg string) { func writeError(w http.ResponseWriter, status int, msg string) {
writeJSON(w, status, map[string]string{"error": msg}) writeJSON(w, status, map[string]string{"error": msg})
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestNew // TestNew
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
c, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{}) c, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{})
if err != nil { if err != nil {
@@ -45,6 +50,7 @@ func TestNew(t *testing.T) {
t.Fatal("expected non-nil client") t.Fatal("expected non-nil client")
} }
} }
func TestNewWithPresetToken(t *testing.T) { func TestNewWithPresetToken(t *testing.T) {
c, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{Token: "preset-tok"}) c, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{Token: "preset-tok"})
if err != nil { if err != nil {
@@ -54,15 +60,18 @@ func TestNewWithPresetToken(t *testing.T) {
t.Errorf("expected preset-tok, got %q", c.Token()) t.Errorf("expected preset-tok, got %q", c.Token())
} }
} }
func TestNewBadCACert(t *testing.T) { func TestNewBadCACert(t *testing.T) {
_, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{CACertPath: "/nonexistent/ca.pem"}) _, err := mciasgoclient.New("https://example.com", mciasgoclient.Options{CACertPath: "/nonexistent/ca.pem"})
if err == nil { if err == nil {
t.Fatal("expected error for missing CA cert file") t.Fatal("expected error for missing CA cert file")
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestHealth // TestHealth
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestHealth(t *testing.T) { func TestHealth(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/health" || r.Method != http.MethodGet { if r.URL.Path != "/v1/health" || r.Method != http.MethodGet {
@@ -77,9 +86,7 @@ func TestHealth(t *testing.T) {
t.Fatalf("Health: unexpected error: %v", err) t.Fatalf("Health: unexpected error: %v", err)
} }
} }
// ---------------------------------------------------------------------------
// TestHealthError
// ---------------------------------------------------------------------------
func TestHealthError(t *testing.T) { func TestHealthError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusServiceUnavailable, "service unavailable") writeError(w, http.StatusServiceUnavailable, "service unavailable")
@@ -98,9 +105,11 @@ func TestHealthError(t *testing.T) {
t.Errorf("expected StatusCode 503, got %d", srvErr.StatusCode) t.Errorf("expected StatusCode 503, got %d", srvErr.StatusCode)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestGetPublicKey // TestGetPublicKey
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestGetPublicKey(t *testing.T) { func TestGetPublicKey(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/keys/public" { if r.URL.Path != "/v1/keys/public" {
@@ -131,9 +140,11 @@ func TestGetPublicKey(t *testing.T) {
t.Error("expected non-empty x") t.Error("expected non-empty x")
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestLogin // TestLogin
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestLogin(t *testing.T) { func TestLogin(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/auth/login" || r.Method != http.MethodPost { if r.URL.Path != "/v1/auth/login" || r.Method != http.MethodPost {
@@ -157,14 +168,11 @@ func TestLogin(t *testing.T) {
if exp == "" { if exp == "" {
t.Error("expected non-empty expires_at") t.Error("expected non-empty expires_at")
} }
// Token must be stored in the client.
if c.Token() != "tok-abc123" { if c.Token() != "tok-abc123" {
t.Errorf("Token() = %q, want tok-abc123", c.Token()) t.Errorf("Token() = %q, want tok-abc123", c.Token())
} }
} }
// ---------------------------------------------------------------------------
// TestLoginUnauthorized
// ---------------------------------------------------------------------------
func TestLoginUnauthorized(t *testing.T) { func TestLoginUnauthorized(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusUnauthorized, "invalid credentials") writeError(w, http.StatusUnauthorized, "invalid credentials")
@@ -180,16 +188,17 @@ func TestLoginUnauthorized(t *testing.T) {
t.Errorf("expected MciasAuthError, got %T: %v", err, err) t.Errorf("expected MciasAuthError, got %T: %v", err, err)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestLogout // TestLogout
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestLogout(t *testing.T) { func TestLogout(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/v1/auth/login": case "/v1/auth/login":
writeJSON(w, http.StatusOK, map[string]string{ writeJSON(w, http.StatusOK, map[string]string{
"token": "tok-logout", "token": "tok-logout", "expires_at": "2099-01-01T00:00:00Z",
"expires_at": "2099-01-01T00:00:00Z",
}) })
case "/v1/auth/logout": case "/v1/auth/logout":
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -212,21 +221,21 @@ func TestLogout(t *testing.T) {
t.Errorf("expected empty token after logout, got %q", c.Token()) t.Errorf("expected empty token after logout, got %q", c.Token())
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestRenewToken // TestRenewToken
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestRenewToken(t *testing.T) { func TestRenewToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path { switch r.URL.Path {
case "/v1/auth/login": case "/v1/auth/login":
writeJSON(w, http.StatusOK, map[string]string{ writeJSON(w, http.StatusOK, map[string]string{
"token": "tok-old", "token": "tok-old", "expires_at": "2099-01-01T00:00:00Z",
"expires_at": "2099-01-01T00:00:00Z",
}) })
case "/v1/auth/renew": case "/v1/auth/renew":
writeJSON(w, http.StatusOK, map[string]string{ writeJSON(w, http.StatusOK, map[string]string{
"token": "tok-new", "token": "tok-new", "expires_at": "2099-06-01T00:00:00Z",
"expires_at": "2099-06-01T00:00:00Z",
}) })
default: default:
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
@@ -248,9 +257,125 @@ func TestRenewToken(t *testing.T) {
t.Errorf("Token() = %q, want tok-new", c.Token()) t.Errorf("Token() = %q, want tok-new", c.Token())
} }
} }
// ---------------------------------------------------------------------------
// TestEnrollTOTP / TestConfirmTOTP
// ---------------------------------------------------------------------------
func TestEnrollTOTP(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/auth/totp/enroll" || r.Method != http.MethodPost {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, map[string]string{
"secret": "JBSWY3DPEHPK3PXP",
"otpauth_uri": "otpauth://totp/MCIAS:alice?secret=JBSWY3DPEHPK3PXP&issuer=MCIAS",
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
resp, err := c.EnrollTOTP()
if err != nil {
t.Fatalf("EnrollTOTP: %v", err)
}
if resp.Secret != "JBSWY3DPEHPK3PXP" {
t.Errorf("expected secret=JBSWY3DPEHPK3PXP, got %q", resp.Secret)
}
if resp.OTPAuthURI == "" {
t.Error("expected non-empty otpauth_uri")
}
}
func TestConfirmTOTP(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/auth/totp/confirm" || r.Method != http.MethodPost {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
if err := c.ConfirmTOTP("123456"); err != nil {
t.Fatalf("ConfirmTOTP: unexpected error: %v", err)
}
}
func TestConfirmTOTPBadCode(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "invalid TOTP code")
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
err := c.ConfirmTOTP("000000")
if err == nil {
t.Fatal("expected error for bad TOTP code")
}
var inputErr *mciasgoclient.MciasInputError
if !errors.As(err, &inputErr) {
t.Errorf("expected MciasInputError, got %T: %v", err, err)
}
}
// ---------------------------------------------------------------------------
// TestChangePassword
// ---------------------------------------------------------------------------
func TestChangePassword(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/auth/password" || r.Method != http.MethodPut {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
if err := c.ChangePassword("old-s3cr3t", "new-s3cr3t-long"); err != nil {
t.Fatalf("ChangePassword: unexpected error: %v", err)
}
}
func TestChangePasswordWrongCurrent(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusUnauthorized, "current password is incorrect")
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
err := c.ChangePassword("wrong", "new-s3cr3t-long")
if err == nil {
t.Fatal("expected error for wrong current password")
}
var authErr *mciasgoclient.MciasAuthError
if !errors.As(err, &authErr) {
t.Errorf("expected MciasAuthError, got %T: %v", err, err)
}
}
// ---------------------------------------------------------------------------
// TestRemoveTOTP
// ---------------------------------------------------------------------------
func TestRemoveTOTP(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/auth/totp" || r.Method != http.MethodDelete {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
if err := c.RemoveTOTP("acct-uuid-42"); err != nil {
t.Fatalf("RemoveTOTP: unexpected error: %v", err)
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestValidateToken // TestValidateToken
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestValidateToken(t *testing.T) { func TestValidateToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/token/validate" { if r.URL.Path != "/v1/token/validate" {
@@ -258,10 +383,8 @@ func TestValidateToken(t *testing.T) {
return return
} }
writeJSON(w, http.StatusOK, map[string]interface{}{ writeJSON(w, http.StatusOK, map[string]interface{}{
"valid": true, "valid": true, "sub": "user-uuid-1",
"sub": "user-uuid-1", "roles": []string{"admin"}, "expires_at": "2099-01-01T00:00:00Z",
"roles": []string{"admin"},
"expires_at": "2099-01-01T00:00:00Z",
}) })
})) }))
defer srv.Close() defer srv.Close()
@@ -277,15 +400,10 @@ func TestValidateToken(t *testing.T) {
t.Errorf("expected sub=user-uuid-1, got %q", claims.Sub) t.Errorf("expected sub=user-uuid-1, got %q", claims.Sub)
} }
} }
// ---------------------------------------------------------------------------
// TestValidateTokenInvalid
// ---------------------------------------------------------------------------
func TestValidateTokenInvalid(t *testing.T) { func TestValidateTokenInvalid(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Server returns 200 with valid=false for an expired/revoked token. writeJSON(w, http.StatusOK, map[string]interface{}{"valid": false})
writeJSON(w, http.StatusOK, map[string]interface{}{
"valid": false,
})
})) }))
defer srv.Close() defer srv.Close()
c := newTestClient(t, srv.URL) c := newTestClient(t, srv.URL)
@@ -297,9 +415,11 @@ func TestValidateTokenInvalid(t *testing.T) {
t.Error("expected claims.Valid = false") t.Error("expected claims.Valid = false")
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestCreateAccount // TestCreateAccount
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestCreateAccount(t *testing.T) { func TestCreateAccount(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/accounts" || r.Method != http.MethodPost { if r.URL.Path != "/v1/accounts" || r.Method != http.MethodPost {
@@ -307,13 +427,9 @@ func TestCreateAccount(t *testing.T) {
return return
} }
writeJSON(w, http.StatusCreated, map[string]interface{}{ writeJSON(w, http.StatusCreated, map[string]interface{}{
"id": "acct-uuid-1", "id": "acct-uuid-1", "username": "bob", "account_type": "human",
"username": "bob", "status": "active", "created_at": "2024-01-01T00:00:00Z",
"account_type": "human", "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false,
"status": "active",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"totp_enabled": false,
}) })
})) }))
defer srv.Close() defer srv.Close()
@@ -329,9 +445,7 @@ func TestCreateAccount(t *testing.T) {
t.Errorf("expected username=bob, got %q", acct.Username) t.Errorf("expected username=bob, got %q", acct.Username)
} }
} }
// ---------------------------------------------------------------------------
// TestCreateAccountConflict
// ---------------------------------------------------------------------------
func TestCreateAccountConflict(t *testing.T) { func TestCreateAccountConflict(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusConflict, "username already exists") writeError(w, http.StatusConflict, "username already exists")
@@ -347,21 +461,19 @@ func TestCreateAccountConflict(t *testing.T) {
t.Errorf("expected MciasConflictError, got %T: %v", err, err) t.Errorf("expected MciasConflictError, got %T: %v", err, err)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestListAccounts // TestListAccounts
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestListAccounts(t *testing.T) { func TestListAccounts(t *testing.T) {
accounts := []map[string]interface{}{ accounts := []map[string]interface{}{
{ {"id": "acct-1", "username": "alice", "account_type": "human",
"id": "acct-1", "username": "alice", "account_type": "human",
"status": "active", "created_at": "2024-01-01T00:00:00Z", "status": "active", "created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false, "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false},
}, {"id": "acct-2", "username": "bob", "account_type": "human",
{
"id": "acct-2", "username": "bob", "account_type": "human",
"status": "active", "created_at": "2024-01-02T00:00:00Z", "status": "active", "created_at": "2024-01-02T00:00:00Z",
"updated_at": "2024-01-02T00:00:00Z", "totp_enabled": false, "updated_at": "2024-01-02T00:00:00Z", "totp_enabled": false},
},
} }
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/accounts" || r.Method != http.MethodGet { if r.URL.Path != "/v1/accounts" || r.Method != http.MethodGet {
@@ -383,27 +495,21 @@ func TestListAccounts(t *testing.T) {
t.Errorf("expected alice, got %q", list[0].Username) t.Errorf("expected alice, got %q", list[0].Username)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestGetAccount // TestGetAccount
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestGetAccount(t *testing.T) { func TestGetAccount(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet || !strings.HasPrefix(r.URL.Path, "/v1/accounts/") {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/accounts/") {
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
return return
} }
writeJSON(w, http.StatusOK, map[string]interface{}{ writeJSON(w, http.StatusOK, map[string]interface{}{
"id": "acct-uuid-42", "id": "acct-uuid-42", "username": "carol", "account_type": "human",
"username": "carol", "status": "active", "created_at": "2024-01-01T00:00:00Z",
"account_type": "human", "updated_at": "2024-01-01T00:00:00Z", "totp_enabled": false,
"status": "active",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z",
"totp_enabled": false,
}) })
})) }))
defer srv.Close() defer srv.Close()
@@ -416,38 +522,30 @@ func TestGetAccount(t *testing.T) {
t.Errorf("expected acct-uuid-42, got %q", acct.ID) t.Errorf("expected acct-uuid-42, got %q", acct.ID)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestUpdateAccount // TestUpdateAccount
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestUpdateAccount(t *testing.T) { func TestUpdateAccount(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPatch { if r.Method != http.MethodPatch {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return return
} }
writeJSON(w, http.StatusOK, map[string]interface{}{ w.WriteHeader(http.StatusNoContent)
"id": "acct-uuid-42",
"username": "carol",
"account_type": "human",
"status": "disabled",
"created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-02-01T00:00:00Z",
"totp_enabled": false,
})
})) }))
defer srv.Close() defer srv.Close()
c := newTestClient(t, srv.URL) c := newTestClient(t, srv.URL)
acct, err := c.UpdateAccount("acct-uuid-42", "disabled") if err := c.UpdateAccount("acct-uuid-42", "inactive"); err != nil {
if err != nil { t.Fatalf("UpdateAccount: unexpected error: %v", err)
t.Fatalf("UpdateAccount: %v", err)
}
if acct.Status != "disabled" {
t.Errorf("expected status=disabled, got %q", acct.Status)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestDeleteAccount // TestDeleteAccount
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestDeleteAccount(t *testing.T) { func TestDeleteAccount(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete { if r.Method != http.MethodDelete {
@@ -462,16 +560,33 @@ func TestDeleteAccount(t *testing.T) {
t.Fatalf("DeleteAccount: unexpected error: %v", err) t.Fatalf("DeleteAccount: unexpected error: %v", err)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestGetRoles // TestAdminSetPassword
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestGetRoles(t *testing.T) {
func TestAdminSetPassword(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodPut || !strings.HasSuffix(r.URL.Path, "/password") {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) http.Error(w, "not found", http.StatusNotFound)
return return
} }
if !strings.HasSuffix(r.URL.Path, "/roles") { w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
if err := c.AdminSetPassword("acct-uuid-42", "new-s3cr3t-long"); err != nil {
t.Fatalf("AdminSetPassword: unexpected error: %v", err)
}
}
// ---------------------------------------------------------------------------
// TestGetRoles / TestSetRoles
// ---------------------------------------------------------------------------
func TestGetRoles(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/roles") {
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
return return
} }
@@ -492,9 +607,7 @@ func TestGetRoles(t *testing.T) {
t.Errorf("expected roles[0]=admin, got %q", roles[0]) t.Errorf("expected roles[0]=admin, got %q", roles[0])
} }
} }
// ---------------------------------------------------------------------------
// TestSetRoles
// ---------------------------------------------------------------------------
func TestSetRoles(t *testing.T) { func TestSetRoles(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut {
@@ -509,9 +622,79 @@ func TestSetRoles(t *testing.T) {
t.Fatalf("SetRoles: unexpected error: %v", err) t.Fatalf("SetRoles: unexpected error: %v", err)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestIssueServiceToken // TestGetAccountTags / TestSetAccountTags
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestGetAccountTags(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/tags") {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"tags": []string{"env:production", "svc:payments-api"},
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
tags, err := c.GetAccountTags("acct-uuid-42")
if err != nil {
t.Fatalf("GetAccountTags: %v", err)
}
if len(tags) != 2 {
t.Errorf("expected 2 tags, got %d", len(tags))
}
if tags[0] != "env:production" {
t.Errorf("expected tags[0]=env:production, got %q", tags[0])
}
}
func TestSetAccountTags(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut || !strings.HasSuffix(r.URL.Path, "/tags") {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"tags": []string{"env:staging"},
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
tags, err := c.SetAccountTags("acct-uuid-42", []string{"env:staging"})
if err != nil {
t.Fatalf("SetAccountTags: unexpected error: %v", err)
}
if len(tags) != 1 || tags[0] != "env:staging" {
t.Errorf("expected [env:staging], got %v", tags)
}
}
func TestSetAccountTagsClear(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{"tags": []string{}})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
tags, err := c.SetAccountTags("acct-uuid-42", []string{})
if err != nil {
t.Fatalf("SetAccountTags (clear): unexpected error: %v", err)
}
if len(tags) != 0 {
t.Errorf("expected empty tags, got %v", tags)
}
}
// ---------------------------------------------------------------------------
// TestIssueServiceToken / TestRevokeToken
// ---------------------------------------------------------------------------
func TestIssueServiceToken(t *testing.T) { func TestIssueServiceToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/token/issue" || r.Method != http.MethodPost { if r.URL.Path != "/v1/token/issue" || r.Method != http.MethodPost {
@@ -519,8 +702,7 @@ func TestIssueServiceToken(t *testing.T) {
return return
} }
writeJSON(w, http.StatusOK, map[string]string{ writeJSON(w, http.StatusOK, map[string]string{
"token": "svc-tok-xyz", "token": "svc-tok-xyz", "expires_at": "2099-01-01T00:00:00Z",
"expires_at": "2099-01-01T00:00:00Z",
}) })
})) }))
defer srv.Close() defer srv.Close()
@@ -536,16 +718,10 @@ func TestIssueServiceToken(t *testing.T) {
t.Error("expected non-empty expires_at") t.Error("expected non-empty expires_at")
} }
} }
// ---------------------------------------------------------------------------
// TestRevokeToken
// ---------------------------------------------------------------------------
func TestRevokeToken(t *testing.T) { func TestRevokeToken(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete { if r.Method != http.MethodDelete || !strings.HasPrefix(r.URL.Path, "/v1/token/") {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !strings.HasPrefix(r.URL.Path, "/v1/token/") {
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
return return
} }
@@ -557,25 +733,20 @@ func TestRevokeToken(t *testing.T) {
t.Fatalf("RevokeToken: unexpected error: %v", err) t.Fatalf("RevokeToken: unexpected error: %v", err)
} }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestGetPGCreds // TestGetPGCreds / TestSetPGCreds
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestGetPGCreds(t *testing.T) { func TestGetPGCreds(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet || !strings.HasSuffix(r.URL.Path, "/pgcreds") {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !strings.HasSuffix(r.URL.Path, "/pgcreds") {
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
return return
} }
writeJSON(w, http.StatusOK, map[string]interface{}{ writeJSON(w, http.StatusOK, map[string]interface{}{
"host": "db.example.com", "host": "db.example.com", "port": 5432,
"port": 5432, "database": "myapp", "username": "appuser", "password": "secretpw",
"database": "myapp",
"username": "appuser",
"password": "secretpw",
}) })
})) }))
defer srv.Close() defer srv.Close()
@@ -594,16 +765,10 @@ func TestGetPGCreds(t *testing.T) {
t.Errorf("expected password=secretpw, got %q", creds.Password) t.Errorf("expected password=secretpw, got %q", creds.Password)
} }
} }
// ---------------------------------------------------------------------------
// TestSetPGCreds
// ---------------------------------------------------------------------------
func TestSetPGCreds(t *testing.T) { func TestSetPGCreds(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPut { if r.Method != http.MethodPut || !strings.HasSuffix(r.URL.Path, "/pgcreds") {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
if !strings.HasSuffix(r.URL.Path, "/pgcreds") {
http.Error(w, "not found", http.StatusNotFound) http.Error(w, "not found", http.StatusNotFound)
return return
} }
@@ -611,14 +776,238 @@ func TestSetPGCreds(t *testing.T) {
})) }))
defer srv.Close() defer srv.Close()
c := newTestClient(t, srv.URL) c := newTestClient(t, srv.URL)
err := c.SetPGCreds("acct-uuid-42", "db.example.com", 5432, "myapp", "appuser", "secretpw") if err := c.SetPGCreds("acct-uuid-42", "db.example.com", 5432, "myapp", "appuser", "secretpw"); err != nil {
if err != nil {
t.Fatalf("SetPGCreds: unexpected error: %v", err) t.Fatalf("SetPGCreds: unexpected error: %v", err)
} }
} }
// ---------------------------------------------------------------------------
// TestListAudit
// ---------------------------------------------------------------------------
func TestListAudit(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/v1/audit") || r.Method != http.MethodGet {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"events": []map[string]interface{}{
{"id": 42, "event_type": "login_ok", "event_time": "2026-03-11T09:01:23Z",
"actor_id": "acct-uuid-1", "ip_address": "192.0.2.1"},
},
"total": 1, "limit": 50, "offset": 0,
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
resp, err := c.ListAudit(mciasgoclient.AuditFilter{})
if err != nil {
t.Fatalf("ListAudit: %v", err)
}
if resp.Total != 1 {
t.Errorf("expected total=1, got %d", resp.Total)
}
if len(resp.Events) != 1 {
t.Fatalf("expected 1 event, got %d", len(resp.Events))
}
if resp.Events[0].EventType != "login_ok" {
t.Errorf("expected event_type=login_ok, got %q", resp.Events[0].EventType)
}
}
func TestListAuditWithFilter(t *testing.T) {
var capturedQuery string
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedQuery = r.URL.RawQuery
writeJSON(w, http.StatusOK, map[string]interface{}{
"events": []map[string]interface{}{},
"total": 0, "limit": 10, "offset": 5,
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
_, err := c.ListAudit(mciasgoclient.AuditFilter{
Limit: 10, Offset: 5, EventType: "login_fail", ActorID: "acct-uuid-1",
})
if err != nil {
t.Fatalf("ListAudit: %v", err)
}
for _, want := range []string{"limit=10", "offset=5", "event_type=login_fail", "actor_id=acct-uuid-1"} {
if !strings.Contains(capturedQuery, want) {
t.Errorf("expected %q in query string, got %q", want, capturedQuery)
}
}
}
// ---------------------------------------------------------------------------
// TestListPolicyRules
// ---------------------------------------------------------------------------
func TestListPolicyRules(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/policy/rules" || r.Method != http.MethodGet {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, []map[string]interface{}{
{
"id": 1, "priority": 100,
"description": "Allow payments-api to read its own pgcreds",
"rule": map[string]interface{}{"effect": "allow", "actions": []string{"pgcreds:read"}},
"enabled": true,
"created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-11T09:00:00Z",
},
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
rules, err := c.ListPolicyRules()
if err != nil {
t.Fatalf("ListPolicyRules: %v", err)
}
if len(rules) != 1 {
t.Fatalf("expected 1 rule, got %d", len(rules))
}
if rules[0].ID != 1 {
t.Errorf("expected id=1, got %d", rules[0].ID)
}
if rules[0].Description != "Allow payments-api to read its own pgcreds" {
t.Errorf("unexpected description: %q", rules[0].Description)
}
}
// ---------------------------------------------------------------------------
// TestCreatePolicyRule
// ---------------------------------------------------------------------------
func TestCreatePolicyRule(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/policy/rules" || r.Method != http.MethodPost {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusCreated, map[string]interface{}{
"id": 7, "priority": 50, "description": "Test rule",
"rule": map[string]interface{}{"effect": "deny"},
"enabled": true,
"created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-11T09:00:00Z",
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
rule, err := c.CreatePolicyRule(mciasgoclient.CreatePolicyRuleRequest{
Description: "Test rule",
Priority: 50,
Rule: mciasgoclient.PolicyRuleBody{Effect: "deny"},
})
if err != nil {
t.Fatalf("CreatePolicyRule: %v", err)
}
if rule.ID != 7 {
t.Errorf("expected id=7, got %d", rule.ID)
}
if rule.Priority != 50 {
t.Errorf("expected priority=50, got %d", rule.Priority)
}
}
// ---------------------------------------------------------------------------
// TestGetPolicyRule
// ---------------------------------------------------------------------------
func TestGetPolicyRule(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet || r.URL.Path != "/v1/policy/rules/7" {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"id": 7, "priority": 50, "description": "Test rule",
"rule": map[string]interface{}{"effect": "allow"},
"enabled": true,
"created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-11T09:00:00Z",
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
rule, err := c.GetPolicyRule(7)
if err != nil {
t.Fatalf("GetPolicyRule: %v", err)
}
if rule.ID != 7 {
t.Errorf("expected id=7, got %d", rule.ID)
}
}
func TestGetPolicyRuleNotFound(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusNotFound, "rule not found")
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
_, err := c.GetPolicyRule(999)
if err == nil {
t.Fatal("expected error for 404")
}
var notFoundErr *mciasgoclient.MciasNotFoundError
if !errors.As(err, &notFoundErr) {
t.Errorf("expected MciasNotFoundError, got %T: %v", err, err)
}
}
// ---------------------------------------------------------------------------
// TestUpdatePolicyRule
// ---------------------------------------------------------------------------
func TestUpdatePolicyRule(t *testing.T) {
enabled := false
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPatch || r.URL.Path != "/v1/policy/rules/7" {
http.Error(w, "not found", http.StatusNotFound)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"id": 7, "priority": 50, "description": "Test rule",
"rule": map[string]interface{}{"effect": "allow"},
"enabled": false,
"created_at": "2026-03-11T09:00:00Z", "updated_at": "2026-03-12T10:00:00Z",
})
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
rule, err := c.UpdatePolicyRule(7, mciasgoclient.UpdatePolicyRuleRequest{Enabled: &enabled})
if err != nil {
t.Fatalf("UpdatePolicyRule: %v", err)
}
if rule.Enabled {
t.Error("expected enabled=false after update")
}
}
// ---------------------------------------------------------------------------
// TestDeletePolicyRule
// ---------------------------------------------------------------------------
func TestDeletePolicyRule(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodDelete || r.URL.Path != "/v1/policy/rules/7" {
http.Error(w, "not found", http.StatusNotFound)
return
}
w.WriteHeader(http.StatusNoContent)
}))
defer srv.Close()
c := newTestClient(t, srv.URL)
if err := c.DeletePolicyRule(7); err != nil {
t.Fatalf("DeletePolicyRule: unexpected error: %v", err)
}
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// TestIntegration: full login → validate → logout flow // TestIntegration: full login → validate → logout flow
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
func TestIntegration(t *testing.T) { func TestIntegration(t *testing.T) {
const sessionToken = "integration-tok-999" const sessionToken = "integration-tok-999"
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -640,8 +1029,7 @@ func TestIntegration(t *testing.T) {
return return
} }
writeJSON(w, http.StatusOK, map[string]string{ writeJSON(w, http.StatusOK, map[string]string{
"token": sessionToken, "token": sessionToken, "expires_at": "2099-01-01T00:00:00Z",
"expires_at": "2099-01-01T00:00:00Z",
}) })
}) })
mux.HandleFunc("/v1/token/validate", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/v1/token/validate", func(w http.ResponseWriter, r *http.Request) {
@@ -658,15 +1046,11 @@ func TestIntegration(t *testing.T) {
} }
if body.Token == sessionToken { if body.Token == sessionToken {
writeJSON(w, http.StatusOK, map[string]interface{}{ writeJSON(w, http.StatusOK, map[string]interface{}{
"valid": true, "valid": true, "sub": "alice-uuid",
"sub": "alice-uuid", "roles": []string{"user"}, "expires_at": "2099-01-01T00:00:00Z",
"roles": []string{"user"},
"expires_at": "2099-01-01T00:00:00Z",
}) })
} else { } else {
writeJSON(w, http.StatusOK, map[string]interface{}{ writeJSON(w, http.StatusOK, map[string]interface{}{"valid": false})
"valid": false,
})
} }
}) })
mux.HandleFunc("/v1/auth/logout", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/v1/auth/logout", func(w http.ResponseWriter, r *http.Request) {
@@ -674,9 +1058,7 @@ func TestIntegration(t *testing.T) {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return return
} }
// Verify Authorization header is present. if r.Header.Get("Authorization") == "" {
auth := r.Header.Get("Authorization")
if auth == "" {
writeError(w, http.StatusUnauthorized, "missing token") writeError(w, http.StatusUnauthorized, "missing token")
return return
} }
@@ -685,7 +1067,8 @@ func TestIntegration(t *testing.T) {
srv := httptest.NewServer(mux) srv := httptest.NewServer(mux)
defer srv.Close() defer srv.Close()
c := newTestClient(t, srv.URL) c := newTestClient(t, srv.URL)
// Step 1: login with wrong credentials should fail.
// Step 1: wrong credentials → MciasAuthError.
_, _, err := c.Login("alice", "wrong-password", "") _, _, err := c.Login("alice", "wrong-password", "")
if err == nil { if err == nil {
t.Fatal("expected error for wrong credentials") t.Fatal("expected error for wrong credentials")
@@ -694,7 +1077,8 @@ func TestIntegration(t *testing.T) {
if !errors.As(err, &authErr) { if !errors.As(err, &authErr) {
t.Errorf("expected MciasAuthError, got %T", err) t.Errorf("expected MciasAuthError, got %T", err)
} }
// Step 2: login with correct credentials.
// Step 2: correct login.
tok, _, err := c.Login("alice", "correct-horse", "") tok, _, err := c.Login("alice", "correct-horse", "")
if err != nil { if err != nil {
t.Fatalf("Login: %v", err) t.Fatalf("Login: %v", err)
@@ -702,7 +1086,8 @@ func TestIntegration(t *testing.T) {
if tok != sessionToken { if tok != sessionToken {
t.Errorf("expected %q, got %q", sessionToken, tok) t.Errorf("expected %q, got %q", sessionToken, tok)
} }
// Step 3: validate the returned token.
// Step 3: validate → valid=true.
claims, err := c.ValidateToken(tok) claims, err := c.ValidateToken(tok)
if err != nil { if err != nil {
t.Fatalf("ValidateToken: %v", err) t.Fatalf("ValidateToken: %v", err)
@@ -713,7 +1098,8 @@ func TestIntegration(t *testing.T) {
if claims.Sub != "alice-uuid" { if claims.Sub != "alice-uuid" {
t.Errorf("expected sub=alice-uuid, got %q", claims.Sub) t.Errorf("expected sub=alice-uuid, got %q", claims.Sub)
} }
// Step 4: validate an unknown token returns Valid=false, not an error.
// Step 4: garbage token → valid=false (not an error).
claims2, err := c.ValidateToken("garbage-token") claims2, err := c.ValidateToken("garbage-token")
if err != nil { if err != nil {
t.Fatalf("ValidateToken(garbage): unexpected error: %v", err) t.Fatalf("ValidateToken(garbage): unexpected error: %v", err)
@@ -721,7 +1107,8 @@ func TestIntegration(t *testing.T) {
if claims2.Valid { if claims2.Valid {
t.Error("expected Valid=false for garbage token") t.Error("expected Valid=false for garbage token")
} }
// Step 5: logout clears the stored token.
// Step 5: logout clears stored token.
if err := c.Logout(); err != nil { if err := c.Logout(); err != nil {
t.Fatalf("Logout: %v", err) t.Fatalf("Logout: %v", err)
} }

View File

@@ -7,9 +7,10 @@ from ._errors import (
MciasForbiddenError, MciasForbiddenError,
MciasInputError, MciasInputError,
MciasNotFoundError, MciasNotFoundError,
MciasRateLimitError,
MciasServerError, MciasServerError,
) )
from ._models import Account, PGCreds, PublicKey, TokenClaims from ._models import Account, PGCreds, PolicyRule, PublicKey, RuleBody, TokenClaims
__all__ = [ __all__ = [
"Client", "Client",
@@ -19,9 +20,12 @@ __all__ = [
"MciasNotFoundError", "MciasNotFoundError",
"MciasInputError", "MciasInputError",
"MciasConflictError", "MciasConflictError",
"MciasRateLimitError",
"MciasServerError", "MciasServerError",
"Account", "Account",
"PublicKey", "PublicKey",
"TokenClaims", "TokenClaims",
"PGCreds", "PGCreds",
"PolicyRule",
"RuleBody",
] ]

View File

@@ -8,7 +8,7 @@ from typing import Any
import httpx import httpx
from ._errors import raise_for_status from ._errors import raise_for_status
from ._models import Account, PGCreds, PublicKey, TokenClaims from ._models import Account, PGCreds, PolicyRule, PublicKey, RuleBody, TokenClaims
class Client: class Client:
@@ -76,6 +76,29 @@ class Client:
if status == 204 or not response.content: if status == 204 or not response.content:
return None return None
return response.json() # type: ignore[no-any-return] return response.json() # type: ignore[no-any-return]
def _request_list(
self,
method: str,
path: str,
*,
json: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
"""Send a request that returns a JSON array at the top level."""
url = f"{self._base_url}{path}"
headers: dict[str, str] = {}
if self.token is not None:
headers["Authorization"] = f"Bearer {self.token}"
response = self._http.request(method, url, json=json, headers=headers)
status = response.status_code
if status >= 400:
try:
body = response.json()
message = str(body.get("error", response.text))
except Exception:
message = response.text
raise_for_status(status, message)
return response.json() # type: ignore[no-any-return]
# ── Public ────────────────────────────────────────────────────────────────
def health(self) -> None: def health(self) -> None:
"""GET /v1/health — liveness check.""" """GET /v1/health — liveness check."""
self._request("GET", "/v1/health") self._request("GET", "/v1/health")
@@ -105,6 +128,12 @@ class Client:
expires_at = str(data["expires_at"]) expires_at = str(data["expires_at"])
self.token = token self.token = token
return token, expires_at return token, expires_at
def validate_token(self, token: str) -> TokenClaims:
"""POST /v1/token/validate — check whether a token is valid."""
data = self._request("POST", "/v1/token/validate", json={"token": token})
assert data is not None
return TokenClaims.from_dict(data)
# ── Authenticated ──────────────────────────────────────────────────────────
def logout(self) -> None: def logout(self) -> None:
"""POST /v1/auth/logout — invalidate the current token.""" """POST /v1/auth/logout — invalidate the current token."""
self._request("POST", "/v1/auth/logout") self._request("POST", "/v1/auth/logout")
@@ -119,11 +148,45 @@ class Client:
expires_at = str(data["expires_at"]) expires_at = str(data["expires_at"])
self.token = token self.token = token
return token, expires_at return token, expires_at
def validate_token(self, token: str) -> TokenClaims: def enroll_totp(self) -> tuple[str, str]:
"""POST /v1/token/validate — check whether a token is valid.""" """POST /v1/auth/totp/enroll — begin TOTP enrollment.
data = self._request("POST", "/v1/token/validate", json={"token": token}) Returns (secret, otpauth_uri). The secret is shown only once.
"""
data = self._request("POST", "/v1/auth/totp/enroll")
assert data is not None assert data is not None
return TokenClaims.from_dict(data) return str(data["secret"]), str(data["otpauth_uri"])
def confirm_totp(self, code: str) -> None:
"""POST /v1/auth/totp/confirm — confirm TOTP enrollment with a code."""
self._request("POST", "/v1/auth/totp/confirm", json={"code": code})
def change_password(self, current_password: str, new_password: str) -> None:
"""PUT /v1/auth/password — change own password (self-service)."""
self._request(
"PUT",
"/v1/auth/password",
json={"current_password": current_password, "new_password": new_password},
)
# ── Admin — Auth ──────────────────────────────────────────────────────────
def remove_totp(self, account_id: str) -> None:
"""DELETE /v1/auth/totp — remove TOTP from an account (admin)."""
self._request("DELETE", "/v1/auth/totp", json={"account_id": account_id})
# ── Admin — Tokens ────────────────────────────────────────────────────────
def issue_service_token(self, account_id: str) -> tuple[str, str]:
"""POST /v1/token/issue — issue a long-lived service token (admin).
Returns (token, expires_at).
"""
data = self._request("POST", "/v1/token/issue", json={"account_id": account_id})
assert data is not None
return str(data["token"]), str(data["expires_at"])
def revoke_token(self, jti: str) -> None:
"""DELETE /v1/token/{jti} — revoke a token by JTI (admin)."""
self._request("DELETE", f"/v1/token/{jti}")
# ── Admin — Accounts ──────────────────────────────────────────────────────
def list_accounts(self) -> list[Account]:
"""GET /v1/accounts — list all accounts (admin).
The API returns a JSON array directly (no wrapper object).
"""
items = self._request_list("GET", "/v1/accounts")
return [Account.from_dict(a) for a in items]
def create_account( def create_account(
self, self,
username: str, username: str,
@@ -131,7 +194,7 @@ class Client:
*, *,
password: str | None = None, password: str | None = None,
) -> Account: ) -> Account:
"""POST /v1/accounts — create a new account.""" """POST /v1/accounts — create a new account (admin)."""
payload: dict[str, Any] = { payload: dict[str, Any] = {
"username": username, "username": username,
"account_type": account_type, "account_type": account_type,
@@ -141,14 +204,8 @@ class Client:
data = self._request("POST", "/v1/accounts", json=payload) data = self._request("POST", "/v1/accounts", json=payload)
assert data is not None assert data is not None
return Account.from_dict(data) return Account.from_dict(data)
def list_accounts(self) -> list[Account]:
"""GET /v1/accounts — list all accounts."""
data = self._request("GET", "/v1/accounts")
assert data is not None
accounts_raw = data.get("accounts") or []
return [Account.from_dict(a) for a in accounts_raw]
def get_account(self, account_id: str) -> Account: def get_account(self, account_id: str) -> Account:
"""GET /v1/accounts/{id} — retrieve a single account.""" """GET /v1/accounts/{id} — retrieve a single account (admin)."""
data = self._request("GET", f"/v1/accounts/{account_id}") data = self._request("GET", f"/v1/accounts/{account_id}")
assert data is not None assert data is not None
return Account.from_dict(data) return Account.from_dict(data)
@@ -157,42 +214,40 @@ class Client:
account_id: str, account_id: str,
*, *,
status: str | None = None, status: str | None = None,
) -> Account: ) -> None:
"""PATCH /v1/accounts/{id} — update account fields.""" """PATCH /v1/accounts/{id} — update account fields (admin).
Currently only `status` is patchable. Returns None (204 No Content).
"""
payload: dict[str, Any] = {} payload: dict[str, Any] = {}
if status is not None: if status is not None:
payload["status"] = status payload["status"] = status
data = self._request("PATCH", f"/v1/accounts/{account_id}", json=payload) self._request("PATCH", f"/v1/accounts/{account_id}", json=payload)
assert data is not None
return Account.from_dict(data)
def delete_account(self, account_id: str) -> None: def delete_account(self, account_id: str) -> None:
"""DELETE /v1/accounts/{id}permanently remove an account.""" """DELETE /v1/accounts/{id}soft-delete an account (admin)."""
self._request("DELETE", f"/v1/accounts/{account_id}") self._request("DELETE", f"/v1/accounts/{account_id}")
def get_roles(self, account_id: str) -> list[str]: def get_roles(self, account_id: str) -> list[str]:
"""GET /v1/accounts/{id}/roles — list roles for an account.""" """GET /v1/accounts/{id}/roles — list roles for an account (admin)."""
data = self._request("GET", f"/v1/accounts/{account_id}/roles") data = self._request("GET", f"/v1/accounts/{account_id}/roles")
assert data is not None assert data is not None
roles_raw = data.get("roles") or [] roles_raw = data.get("roles") or []
return [str(r) for r in roles_raw] return [str(r) for r in roles_raw]
def set_roles(self, account_id: str, roles: list[str]) -> None: def set_roles(self, account_id: str, roles: list[str]) -> None:
"""PUT /v1/accounts/{id}/roles — replace the full role set.""" """PUT /v1/accounts/{id}/roles — replace the full role set (admin)."""
self._request( self._request(
"PUT", "PUT",
f"/v1/accounts/{account_id}/roles", f"/v1/accounts/{account_id}/roles",
json={"roles": roles}, json={"roles": roles},
) )
def issue_service_token(self, account_id: str) -> tuple[str, str]: def admin_set_password(self, account_id: str, new_password: str) -> None:
"""POST /v1/accounts/{id}/token — issue a long-lived service token. """PUT /v1/accounts/{id}/password — reset a password without the old one (admin)."""
Returns (token, expires_at). self._request(
""" "PUT",
data = self._request("POST", f"/v1/accounts/{account_id}/token") f"/v1/accounts/{account_id}/password",
assert data is not None json={"new_password": new_password},
return str(data["token"]), str(data["expires_at"]) )
def revoke_token(self, jti: str) -> None: # ── Admin — Credentials ───────────────────────────────────────────────────
"""DELETE /v1/token/{jti} — revoke a token by JTI."""
self._request("DELETE", f"/v1/token/{jti}")
def get_pg_creds(self, account_id: str) -> PGCreds: def get_pg_creds(self, account_id: str) -> PGCreds:
"""GET /v1/accounts/{id}/pgcreds — retrieve Postgres credentials.""" """GET /v1/accounts/{id}/pgcreds — retrieve Postgres credentials (admin)."""
data = self._request("GET", f"/v1/accounts/{account_id}/pgcreds") data = self._request("GET", f"/v1/accounts/{account_id}/pgcreds")
assert data is not None assert data is not None
return PGCreds.from_dict(data) return PGCreds.from_dict(data)
@@ -205,7 +260,7 @@ class Client:
username: str, username: str,
password: str, password: str,
) -> None: ) -> None:
"""PUT /v1/accounts/{id}/pgcreds — store or replace Postgres credentials.""" """PUT /v1/accounts/{id}/pgcreds — store or replace Postgres credentials (admin)."""
payload: dict[str, Any] = { payload: dict[str, Any] = {
"host": host, "host": host,
"port": port, "port": port,
@@ -214,3 +269,89 @@ class Client:
"password": password, "password": password,
} }
self._request("PUT", f"/v1/accounts/{account_id}/pgcreds", json=payload) self._request("PUT", f"/v1/accounts/{account_id}/pgcreds", json=payload)
# ── Admin — Policy ────────────────────────────────────────────────────────
def get_account_tags(self, account_id: str) -> list[str]:
"""GET /v1/accounts/{id}/tags — get account tags (admin)."""
data = self._request("GET", f"/v1/accounts/{account_id}/tags")
assert data is not None
return [str(t) for t in (data.get("tags") or [])]
def set_account_tags(self, account_id: str, tags: list[str]) -> list[str]:
"""PUT /v1/accounts/{id}/tags — replace the full tag set (admin).
Returns the updated tag list.
"""
data = self._request(
"PUT",
f"/v1/accounts/{account_id}/tags",
json={"tags": tags},
)
assert data is not None
return [str(t) for t in (data.get("tags") or [])]
def list_policy_rules(self) -> list[PolicyRule]:
"""GET /v1/policy/rules — list all operator policy rules (admin)."""
items = self._request_list("GET", "/v1/policy/rules")
return [PolicyRule.from_dict(r) for r in items]
def create_policy_rule(
self,
description: str,
rule: RuleBody,
*,
priority: int | None = None,
not_before: str | None = None,
expires_at: str | None = None,
) -> PolicyRule:
"""POST /v1/policy/rules — create a policy rule (admin)."""
payload: dict[str, Any] = {
"description": description,
"rule": rule.to_dict(),
}
if priority is not None:
payload["priority"] = priority
if not_before is not None:
payload["not_before"] = not_before
if expires_at is not None:
payload["expires_at"] = expires_at
data = self._request("POST", "/v1/policy/rules", json=payload)
assert data is not None
return PolicyRule.from_dict(data)
def get_policy_rule(self, rule_id: int) -> PolicyRule:
"""GET /v1/policy/rules/{id} — get a policy rule (admin)."""
data = self._request("GET", f"/v1/policy/rules/{rule_id}")
assert data is not None
return PolicyRule.from_dict(data)
def update_policy_rule(
self,
rule_id: int,
*,
description: str | None = None,
priority: int | None = None,
enabled: bool | None = None,
rule: RuleBody | None = None,
not_before: str | None = None,
expires_at: str | None = None,
clear_not_before: bool | None = None,
clear_expires_at: bool | None = None,
) -> PolicyRule:
"""PATCH /v1/policy/rules/{id} — update a policy rule (admin)."""
payload: dict[str, Any] = {}
if description is not None:
payload["description"] = description
if priority is not None:
payload["priority"] = priority
if enabled is not None:
payload["enabled"] = enabled
if rule is not None:
payload["rule"] = rule.to_dict()
if not_before is not None:
payload["not_before"] = not_before
if expires_at is not None:
payload["expires_at"] = expires_at
if clear_not_before is not None:
payload["clear_not_before"] = clear_not_before
if clear_expires_at is not None:
payload["clear_expires_at"] = clear_expires_at
data = self._request("PATCH", f"/v1/policy/rules/{rule_id}", json=payload)
assert data is not None
return PolicyRule.from_dict(data)
def delete_policy_rule(self, rule_id: int) -> None:
"""DELETE /v1/policy/rules/{id} — delete a policy rule (admin)."""
self._request("DELETE", f"/v1/policy/rules/{rule_id}")

View File

@@ -15,6 +15,8 @@ class MciasInputError(MciasError):
"""400 Bad Request — malformed request.""" """400 Bad Request — malformed request."""
class MciasConflictError(MciasError): class MciasConflictError(MciasError):
"""409 Conflict — e.g. duplicate username.""" """409 Conflict — e.g. duplicate username."""
class MciasRateLimitError(MciasError):
"""429 Too Many Requests — rate limit exceeded."""
class MciasServerError(MciasError): class MciasServerError(MciasError):
"""5xx — unexpected server error.""" """5xx — unexpected server error."""
def raise_for_status(status_code: int, message: str) -> None: def raise_for_status(status_code: int, message: str) -> None:
@@ -25,6 +27,7 @@ def raise_for_status(status_code: int, message: str) -> None:
403: MciasForbiddenError, 403: MciasForbiddenError,
404: MciasNotFoundError, 404: MciasNotFoundError,
409: MciasConflictError, 409: MciasConflictError,
429: MciasRateLimitError,
} }
cls = exc_map.get(status_code, MciasServerError) cls = exc_map.get(status_code, MciasServerError)
raise cls(status_code, message) raise cls(status_code, message)

View File

@@ -1,6 +1,6 @@
"""Data models for MCIAS API responses.""" """Data models for MCIAS API responses."""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import cast from typing import Any, cast
@dataclass @dataclass
@@ -74,3 +74,73 @@ class PGCreds:
username=str(d["username"]), username=str(d["username"]),
password=str(d["password"]), password=str(d["password"]),
) )
@dataclass
class RuleBody:
"""Match conditions and effect of a policy rule."""
effect: str
roles: list[str] = field(default_factory=list)
account_types: list[str] = field(default_factory=list)
subject_uuid: str | None = None
actions: list[str] = field(default_factory=list)
resource_type: str | None = None
owner_matches_subject: bool | None = None
service_names: list[str] = field(default_factory=list)
required_tags: list[str] = field(default_factory=list)
@classmethod
def from_dict(cls, d: dict[str, object]) -> "RuleBody":
return cls(
effect=str(d["effect"]),
roles=[str(r) for r in cast(list[Any], d.get("roles") or [])],
account_types=[str(t) for t in cast(list[Any], d.get("account_types") or [])],
subject_uuid=str(d["subject_uuid"]) if d.get("subject_uuid") is not None else None,
actions=[str(a) for a in cast(list[Any], d.get("actions") or [])],
resource_type=str(d["resource_type"]) if d.get("resource_type") is not None else None,
owner_matches_subject=bool(d["owner_matches_subject"]) if d.get("owner_matches_subject") is not None else None,
service_names=[str(s) for s in cast(list[Any], d.get("service_names") or [])],
required_tags=[str(t) for t in cast(list[Any], d.get("required_tags") or [])],
)
def to_dict(self) -> dict[str, Any]:
"""Serialise to a JSON-compatible dict, omitting None/empty fields."""
out: dict[str, Any] = {"effect": self.effect}
if self.roles:
out["roles"] = self.roles
if self.account_types:
out["account_types"] = self.account_types
if self.subject_uuid is not None:
out["subject_uuid"] = self.subject_uuid
if self.actions:
out["actions"] = self.actions
if self.resource_type is not None:
out["resource_type"] = self.resource_type
if self.owner_matches_subject is not None:
out["owner_matches_subject"] = self.owner_matches_subject
if self.service_names:
out["service_names"] = self.service_names
if self.required_tags:
out["required_tags"] = self.required_tags
return out
@dataclass
class PolicyRule:
"""An operator-defined policy rule."""
id: int
priority: int
description: str
rule: RuleBody
enabled: bool
created_at: str
updated_at: str
not_before: str | None = None
expires_at: str | None = None
@classmethod
def from_dict(cls, d: dict[str, object]) -> "PolicyRule":
return cls(
id=int(cast(int, d["id"])),
priority=int(cast(int, d["priority"])),
description=str(d["description"]),
rule=RuleBody.from_dict(cast(dict[str, object], d["rule"])),
enabled=bool(d["enabled"]),
created_at=str(d["created_at"]),
updated_at=str(d["updated_at"]),
not_before=str(d["not_before"]) if d.get("not_before") is not None else None,
expires_at=str(d["expires_at"]) if d.get("expires_at") is not None else None,
)

View File

@@ -13,15 +13,16 @@ from mcias_client import (
MciasForbiddenError, MciasForbiddenError,
MciasInputError, MciasInputError,
MciasNotFoundError, MciasNotFoundError,
MciasRateLimitError,
MciasServerError, MciasServerError,
) )
from mcias_client._models import Account, PGCreds, PublicKey, TokenClaims from mcias_client._models import Account, PGCreds, PolicyRule, PublicKey, RuleBody, TokenClaims
BASE_URL = "https://auth.example.com" BASE_URL = "https://auth.example.com"
SAMPLE_ACCOUNT: dict[str, object] = { SAMPLE_ACCOUNT: dict[str, object] = {
"id": "acc-001", "id": "acc-001",
"username": "alice", "username": "alice",
"account_type": "user", "account_type": "human",
"status": "active", "status": "active",
"created_at": "2024-01-01T00:00:00Z", "created_at": "2024-01-01T00:00:00Z",
"updated_at": "2024-01-01T00:00:00Z", "updated_at": "2024-01-01T00:00:00Z",
@@ -34,6 +35,24 @@ SAMPLE_PK: dict[str, object] = {
"use": "sig", "use": "sig",
"alg": "EdDSA", "alg": "EdDSA",
} }
SAMPLE_RULE_BODY: dict[str, object] = {
"effect": "allow",
"roles": ["svc:payments-api"],
"actions": ["pgcreds:read"],
"resource_type": "pgcreds",
"owner_matches_subject": True,
}
SAMPLE_POLICY_RULE: dict[str, object] = {
"id": 1,
"priority": 100,
"description": "Allow payments-api to read its own pgcreds",
"rule": SAMPLE_RULE_BODY,
"enabled": True,
"not_before": None,
"expires_at": None,
"created_at": "2026-03-11T09:00:00Z",
"updated_at": "2026-03-11T09:00:00Z",
}
@pytest.fixture @pytest.fixture
def client() -> Client: def client() -> Client:
return Client(BASE_URL) return Client(BASE_URL)
@@ -88,6 +107,16 @@ def test_login_success(client: Client) -> None:
assert expires_at == "2099-01-01T00:00:00Z" assert expires_at == "2099-01-01T00:00:00Z"
assert client.token == "jwt-token-abc" assert client.token == "jwt-token-abc"
@respx.mock @respx.mock
def test_login_with_totp(client: Client) -> None:
respx.post(f"{BASE_URL}/v1/auth/login").mock(
return_value=httpx.Response(
200,
json={"token": "jwt-token-totp", "expires_at": "2099-01-01T00:00:00Z"},
)
)
token, _ = client.login("alice", "s3cr3t", totp_code="123456")
assert token == "jwt-token-totp"
@respx.mock
def test_login_unauthorized(client: Client) -> None: def test_login_unauthorized(client: Client) -> None:
respx.post(f"{BASE_URL}/v1/auth/login").mock( respx.post(f"{BASE_URL}/v1/auth/login").mock(
return_value=httpx.Response( return_value=httpx.Response(
@@ -98,6 +127,14 @@ def test_login_unauthorized(client: Client) -> None:
client.login("alice", "wrong") client.login("alice", "wrong")
assert exc_info.value.status_code == 401 assert exc_info.value.status_code == 401
@respx.mock @respx.mock
def test_login_rate_limited(client: Client) -> None:
respx.post(f"{BASE_URL}/v1/auth/login").mock(
return_value=httpx.Response(429, json={"error": "rate limit exceeded", "code": "rate_limited"})
)
with pytest.raises(MciasRateLimitError) as exc_info:
client.login("alice", "s3cr3t")
assert exc_info.value.status_code == 429
@respx.mock
def test_logout_clears_token(admin_client: Client) -> None: def test_logout_clears_token(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/auth/logout").mock( respx.post(f"{BASE_URL}/v1/auth/logout").mock(
return_value=httpx.Response(204) return_value=httpx.Response(204)
@@ -147,11 +184,58 @@ def test_validate_token_invalid(admin_client: Client) -> None:
claims = admin_client.validate_token("expired-token") claims = admin_client.validate_token("expired-token")
assert claims.valid is False assert claims.valid is False
@respx.mock @respx.mock
def test_enroll_totp(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/auth/totp/enroll").mock(
return_value=httpx.Response(
200,
json={"secret": "JBSWY3DPEHPK3PXP", "otpauth_uri": "otpauth://totp/MCIAS:alice?secret=JBSWY3DPEHPK3PXP&issuer=MCIAS"},
)
)
secret, uri = admin_client.enroll_totp()
assert secret == "JBSWY3DPEHPK3PXP"
assert "otpauth://totp/" in uri
@respx.mock
def test_confirm_totp(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/auth/totp/confirm").mock(
return_value=httpx.Response(204)
)
admin_client.confirm_totp("123456") # should not raise
@respx.mock
def test_change_password(admin_client: Client) -> None:
respx.put(f"{BASE_URL}/v1/auth/password").mock(
return_value=httpx.Response(204)
)
admin_client.change_password("old-pass", "new-pass-long-enough") # should not raise
@respx.mock
def test_remove_totp(admin_client: Client) -> None:
respx.delete(f"{BASE_URL}/v1/auth/totp").mock(
return_value=httpx.Response(204)
)
admin_client.remove_totp("acc-001") # should not raise
@respx.mock
def test_issue_service_token(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/token/issue").mock(
return_value=httpx.Response(
200,
json={"token": "svc-token-xyz", "expires_at": "2099-12-31T00:00:00Z"},
)
)
token, expires_at = admin_client.issue_service_token("acc-001")
assert token == "svc-token-xyz"
assert expires_at == "2099-12-31T00:00:00Z"
@respx.mock
def test_revoke_token(admin_client: Client) -> None:
jti = "some-jti-uuid"
respx.delete(f"{BASE_URL}/v1/token/{jti}").mock(
return_value=httpx.Response(204)
)
admin_client.revoke_token(jti) # should not raise
@respx.mock
def test_create_account(admin_client: Client) -> None: def test_create_account(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/accounts").mock( respx.post(f"{BASE_URL}/v1/accounts").mock(
return_value=httpx.Response(201, json=SAMPLE_ACCOUNT) return_value=httpx.Response(201, json=SAMPLE_ACCOUNT)
) )
acc = admin_client.create_account("alice", "user", password="pass123") acc = admin_client.create_account("alice", "human", password="pass123")
assert isinstance(acc, Account) assert isinstance(acc, Account)
assert acc.id == "acc-001" assert acc.id == "acc-001"
assert acc.username == "alice" assert acc.username == "alice"
@@ -161,15 +245,14 @@ def test_create_account_conflict(admin_client: Client) -> None:
return_value=httpx.Response(409, json={"error": "username already exists"}) return_value=httpx.Response(409, json={"error": "username already exists"})
) )
with pytest.raises(MciasConflictError) as exc_info: with pytest.raises(MciasConflictError) as exc_info:
admin_client.create_account("alice", "user") admin_client.create_account("alice", "human")
assert exc_info.value.status_code == 409 assert exc_info.value.status_code == 409
@respx.mock @respx.mock
def test_list_accounts(admin_client: Client) -> None: def test_list_accounts(admin_client: Client) -> None:
second = {**SAMPLE_ACCOUNT, "id": "acc-002"} second = {**SAMPLE_ACCOUNT, "id": "acc-002"}
# API returns a plain JSON array, not a wrapper object
respx.get(f"{BASE_URL}/v1/accounts").mock( respx.get(f"{BASE_URL}/v1/accounts").mock(
return_value=httpx.Response( return_value=httpx.Response(200, json=[SAMPLE_ACCOUNT, second])
200, json={"accounts": [SAMPLE_ACCOUNT, second]}
)
) )
accounts = admin_client.list_accounts() accounts = admin_client.list_accounts()
assert len(accounts) == 2 assert len(accounts) == 2
@@ -183,12 +266,12 @@ def test_get_account(admin_client: Client) -> None:
assert acc.id == "acc-001" assert acc.id == "acc-001"
@respx.mock @respx.mock
def test_update_account(admin_client: Client) -> None: def test_update_account(admin_client: Client) -> None:
updated = {**SAMPLE_ACCOUNT, "status": "suspended"} # PATCH /v1/accounts/{id} returns 204 No Content
respx.patch(f"{BASE_URL}/v1/accounts/acc-001").mock( respx.patch(f"{BASE_URL}/v1/accounts/acc-001").mock(
return_value=httpx.Response(200, json=updated) return_value=httpx.Response(204)
) )
acc = admin_client.update_account("acc-001", status="suspended") result = admin_client.update_account("acc-001", status="inactive")
assert acc.status == "suspended" assert result is None
@respx.mock @respx.mock
def test_delete_account(admin_client: Client) -> None: def test_delete_account(admin_client: Client) -> None:
respx.delete(f"{BASE_URL}/v1/accounts/acc-001").mock( respx.delete(f"{BASE_URL}/v1/accounts/acc-001").mock(
@@ -209,23 +292,11 @@ def test_set_roles(admin_client: Client) -> None:
) )
admin_client.set_roles("acc-001", ["viewer"]) # should not raise admin_client.set_roles("acc-001", ["viewer"]) # should not raise
@respx.mock @respx.mock
def test_issue_service_token(admin_client: Client) -> None: def test_admin_set_password(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/accounts/acc-001/token").mock( respx.put(f"{BASE_URL}/v1/accounts/acc-001/password").mock(
return_value=httpx.Response(
200,
json={"token": "svc-token-xyz", "expires_at": "2099-12-31T00:00:00Z"},
)
)
token, expires_at = admin_client.issue_service_token("acc-001")
assert token == "svc-token-xyz"
assert expires_at == "2099-12-31T00:00:00Z"
@respx.mock
def test_revoke_token(admin_client: Client) -> None:
jti = "some-jti-uuid"
respx.delete(f"{BASE_URL}/v1/token/{jti}").mock(
return_value=httpx.Response(204) return_value=httpx.Response(204)
) )
admin_client.revoke_token(jti) # should not raise admin_client.admin_set_password("acc-001", "new-secure-password") # should not raise
SAMPLE_PG_CREDS: dict[str, object] = { SAMPLE_PG_CREDS: dict[str, object] = {
"host": "db.example.com", "host": "db.example.com",
"port": 5432, "port": 5432,
@@ -256,6 +327,68 @@ def test_set_pg_creds(admin_client: Client) -> None:
username="appuser", username="appuser",
password="s3cr3t", password="s3cr3t",
) # should not raise ) # should not raise
@respx.mock
def test_get_account_tags(admin_client: Client) -> None:
respx.get(f"{BASE_URL}/v1/accounts/acc-001/tags").mock(
return_value=httpx.Response(200, json={"tags": ["env:production", "svc:payments-api"]})
)
tags = admin_client.get_account_tags("acc-001")
assert tags == ["env:production", "svc:payments-api"]
@respx.mock
def test_set_account_tags(admin_client: Client) -> None:
respx.put(f"{BASE_URL}/v1/accounts/acc-001/tags").mock(
return_value=httpx.Response(200, json={"tags": ["env:staging"]})
)
tags = admin_client.set_account_tags("acc-001", ["env:staging"])
assert tags == ["env:staging"]
@respx.mock
def test_list_policy_rules(admin_client: Client) -> None:
respx.get(f"{BASE_URL}/v1/policy/rules").mock(
return_value=httpx.Response(200, json=[SAMPLE_POLICY_RULE])
)
rules = admin_client.list_policy_rules()
assert len(rules) == 1
assert isinstance(rules[0], PolicyRule)
assert rules[0].id == 1
assert rules[0].rule.effect == "allow"
@respx.mock
def test_create_policy_rule(admin_client: Client) -> None:
respx.post(f"{BASE_URL}/v1/policy/rules").mock(
return_value=httpx.Response(201, json=SAMPLE_POLICY_RULE)
)
rule_body = RuleBody(effect="allow", actions=["pgcreds:read"], resource_type="pgcreds")
rule = admin_client.create_policy_rule(
"Allow payments-api to read its own pgcreds",
rule_body,
priority=50,
)
assert isinstance(rule, PolicyRule)
assert rule.id == 1
assert rule.description == "Allow payments-api to read its own pgcreds"
@respx.mock
def test_get_policy_rule(admin_client: Client) -> None:
respx.get(f"{BASE_URL}/v1/policy/rules/1").mock(
return_value=httpx.Response(200, json=SAMPLE_POLICY_RULE)
)
rule = admin_client.get_policy_rule(1)
assert isinstance(rule, PolicyRule)
assert rule.id == 1
assert rule.enabled is True
@respx.mock
def test_update_policy_rule(admin_client: Client) -> None:
updated = {**SAMPLE_POLICY_RULE, "enabled": False}
respx.patch(f"{BASE_URL}/v1/policy/rules/1").mock(
return_value=httpx.Response(200, json=updated)
)
rule = admin_client.update_policy_rule(1, enabled=False)
assert isinstance(rule, PolicyRule)
assert rule.enabled is False
@respx.mock
def test_delete_policy_rule(admin_client: Client) -> None:
respx.delete(f"{BASE_URL}/v1/policy/rules/1").mock(
return_value=httpx.Response(204)
)
admin_client.delete_policy_rule(1) # should not raise
@pytest.mark.parametrize( @pytest.mark.parametrize(
("status_code", "exc_class"), ("status_code", "exc_class"),
[ [
@@ -264,6 +397,7 @@ def test_set_pg_creds(admin_client: Client) -> None:
(403, MciasForbiddenError), (403, MciasForbiddenError),
(404, MciasNotFoundError), (404, MciasNotFoundError),
(409, MciasConflictError), (409, MciasConflictError),
(429, MciasRateLimitError),
(500, MciasServerError), (500, MciasServerError),
], ],
) )

View File

@@ -70,7 +70,7 @@ pub enum MciasError {
Decode(String), Decode(String),
} }
// ---- Data types ---- // ---- Public data types ----
/// Account information returned by the server. /// Account information returned by the server.
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
@@ -101,6 +101,11 @@ pub struct TokenClaims {
pub struct PublicKey { pub struct PublicKey {
pub kty: String, pub kty: String,
pub crv: String, pub crv: String,
/// Key use — always `"sig"` for the MCIAS signing key.
#[serde(rename = "use")]
pub key_use: Option<String>,
/// Algorithm — always `"EdDSA"`. Validate this before trusting the key.
pub alg: Option<String>,
pub x: String, pub x: String,
} }
@@ -114,6 +119,106 @@ pub struct PgCreds {
pub password: String, pub password: String,
} }
/// Audit log entry returned by `GET /v1/audit`.
#[derive(Debug, Clone, Deserialize)]
pub struct AuditEvent {
pub id: i64,
pub event_type: String,
pub event_time: String,
pub ip_address: String,
pub actor_id: Option<String>,
pub target_id: Option<String>,
pub details: Option<String>,
}
/// Paginated response from `GET /v1/audit`.
#[derive(Debug, Clone, Deserialize)]
pub struct AuditPage {
pub events: Vec<AuditEvent>,
pub total: i64,
pub limit: i64,
pub offset: i64,
}
/// Query parameters for `GET /v1/audit`.
#[derive(Debug, Clone, Default)]
pub struct AuditQuery {
pub limit: Option<u32>,
pub offset: Option<u32>,
pub event_type: Option<String>,
pub actor_id: Option<String>,
}
/// A single operator-defined policy rule.
#[derive(Debug, Clone, Deserialize)]
pub struct PolicyRule {
pub id: i64,
pub priority: i64,
pub description: String,
pub rule: RuleBody,
pub enabled: bool,
pub not_before: Option<String>,
pub expires_at: Option<String>,
pub created_at: String,
pub updated_at: String,
}
/// The match conditions and effect of a policy rule.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleBody {
pub effect: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub roles: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub account_types: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub subject_uuid: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub actions: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resource_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub owner_matches_subject: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_names: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required_tags: Option<Vec<String>>,
}
/// Request body for `POST /v1/policy/rules`.
#[derive(Debug, Clone, Serialize)]
pub struct CreatePolicyRuleRequest {
pub description: String,
pub rule: RuleBody,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub not_before: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<String>,
}
/// Request body for `PATCH /v1/policy/rules/{id}`.
#[derive(Debug, Clone, Serialize, Default)]
pub struct UpdatePolicyRuleRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub priority: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enabled: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub rule: Option<RuleBody>,
#[serde(skip_serializing_if = "Option::is_none")]
pub not_before: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub clear_not_before: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub clear_expires_at: Option<bool>,
}
// ---- Internal request/response types ---- // ---- Internal request/response types ----
#[derive(Serialize)] #[derive(Serialize)]
@@ -136,6 +241,22 @@ struct ErrorResponse {
error: String, error: String,
} }
#[derive(Deserialize)]
struct RolesResponse {
roles: Vec<String>,
}
#[derive(Deserialize)]
struct TagsResponse {
tags: Vec<String>,
}
#[derive(Deserialize)]
struct TotpEnrollResponse {
secret: String,
otpauth_uri: String,
}
// ---- Client options ---- // ---- Client options ----
/// Configuration options for the MCIAS client. /// Configuration options for the MCIAS client.
@@ -160,6 +281,7 @@ pub struct Client {
base_url: String, base_url: String,
http: reqwest::Client, http: reqwest::Client,
/// Bearer token storage. `Arc<RwLock<...>>` so clones share the token. /// Bearer token storage. `Arc<RwLock<...>>` so clones share the token.
/// Security: the token is never logged or included in error messages.
token: Arc<RwLock<Option<String>>>, token: Arc<RwLock<Option<String>>>,
} }
@@ -285,9 +407,9 @@ impl Client {
} }
/// Update an account's status. Allowed values: `"active"`, `"inactive"`. /// Update an account's status. Allowed values: `"active"`, `"inactive"`.
pub async fn update_account(&self, id: &str, status: &str) -> Result<Account, MciasError> { pub async fn update_account(&self, id: &str, status: &str) -> Result<(), MciasError> {
let body = serde_json::json!({ "status": status }); let body = serde_json::json!({ "status": status });
self.patch(&format!("/v1/accounts/{id}"), &body).await self.patch_no_content(&format!("/v1/accounts/{id}"), &body).await
} }
/// Soft-delete an account and revoke all its tokens. /// Soft-delete an account and revoke all its tokens.
@@ -299,13 +421,17 @@ impl Client {
/// Get all roles assigned to an account. /// Get all roles assigned to an account.
pub async fn get_roles(&self, account_id: &str) -> Result<Vec<String>, MciasError> { pub async fn get_roles(&self, account_id: &str) -> Result<Vec<String>, MciasError> {
self.get(&format!("/v1/accounts/{account_id}/roles")).await // Security: spec wraps roles in {"roles": [...]}, unwrap before returning.
let resp: RolesResponse = self.get(&format!("/v1/accounts/{account_id}/roles")).await?;
Ok(resp.roles)
} }
/// Replace the complete role set for an account. /// Replace the complete role set for an account.
pub async fn set_roles(&self, account_id: &str, roles: &[&str]) -> Result<(), MciasError> { pub async fn set_roles(&self, account_id: &str, roles: &[&str]) -> Result<(), MciasError> {
let url = format!("/v1/accounts/{account_id}/roles"); let url = format!("/v1/accounts/{account_id}/roles");
self.put_no_content(&url, roles).await // Spec requires {"roles": [...]} wrapper.
let body = serde_json::json!({ "roles": roles });
self.put_no_content(&url, &body).await
} }
// ---- Token management (admin only) ---- // ---- Token management (admin only) ----
@@ -354,10 +480,142 @@ impl Client {
.await .await
} }
// ---- TOTP enrollment (authenticated) ----
/// Begin TOTP enrollment. Returns `(secret, otpauth_uri)`.
/// The secret is shown once; store it in an authenticator app immediately.
pub async fn enroll_totp(&self) -> Result<(String, String), MciasError> {
let resp: TotpEnrollResponse =
self.post("/v1/auth/totp/enroll", &serde_json::json!({})).await?;
Ok((resp.secret, resp.otpauth_uri))
}
/// Confirm TOTP enrollment with the current 6-digit code.
/// On success, TOTP becomes required for all future logins.
pub async fn confirm_totp(&self, code: &str) -> Result<(), MciasError> {
let body = serde_json::json!({ "code": code });
self.post_empty_body("/v1/auth/totp/confirm", &body).await
}
// ---- Password management ----
/// Change the caller's own password (self-service). Requires the current
/// password to guard against token-theft attacks.
pub async fn change_password(
&self,
current_password: &str,
new_password: &str,
) -> Result<(), MciasError> {
let body = serde_json::json!({
"current_password": current_password,
"new_password": new_password,
});
self.put_no_content("/v1/auth/password", &body).await
}
// ---- Admin: TOTP removal ----
/// Remove TOTP enrollment from an account (admin). Use for recovery when
/// a user loses their TOTP device.
pub async fn remove_totp(&self, account_id: &str) -> Result<(), MciasError> {
let body = serde_json::json!({ "account_id": account_id });
self.delete_with_body("/v1/auth/totp", &body).await
}
// ---- Admin: password reset ----
/// Reset an account's password without requiring the current password.
pub async fn admin_set_password(
&self,
account_id: &str,
new_password: &str,
) -> Result<(), MciasError> {
let body = serde_json::json!({ "new_password": new_password });
self.put_no_content(&format!("/v1/accounts/{account_id}/password"), &body)
.await
}
// ---- Account tags (admin) ----
/// Get all tags for an account.
pub async fn get_tags(&self, account_id: &str) -> Result<Vec<String>, MciasError> {
let resp: TagsResponse =
self.get(&format!("/v1/accounts/{account_id}/tags")).await?;
Ok(resp.tags)
}
/// Replace the full tag set for an account atomically. Pass an empty slice
/// to clear all tags. Returns the updated tag list.
pub async fn set_tags(
&self,
account_id: &str,
tags: &[&str],
) -> Result<Vec<String>, MciasError> {
let body = serde_json::json!({ "tags": tags });
let resp: TagsResponse =
self.put_with_response(&format!("/v1/accounts/{account_id}/tags"), &body).await?;
Ok(resp.tags)
}
// ---- Audit log (admin) ----
/// Query the audit log. Returns a paginated [`AuditPage`].
pub async fn list_audit(&self, query: AuditQuery) -> Result<AuditPage, MciasError> {
let mut params: Vec<(&str, String)> = Vec::new();
if let Some(limit) = query.limit {
params.push(("limit", limit.to_string()));
}
if let Some(offset) = query.offset {
params.push(("offset", offset.to_string()));
}
if let Some(ref et) = query.event_type {
params.push(("event_type", et.clone()));
}
if let Some(ref aid) = query.actor_id {
params.push(("actor_id", aid.clone()));
}
self.get_with_query("/v1/audit", &params).await
}
// ---- Policy rules (admin) ----
/// List all operator-defined policy rules ordered by priority.
pub async fn list_policy_rules(&self) -> Result<Vec<PolicyRule>, MciasError> {
self.get("/v1/policy/rules").await
}
/// Create a new policy rule.
pub async fn create_policy_rule(
&self,
req: CreatePolicyRuleRequest,
) -> Result<PolicyRule, MciasError> {
self.post_expect_status("/v1/policy/rules", &req, StatusCode::CREATED)
.await
}
/// Get a single policy rule by ID.
pub async fn get_policy_rule(&self, id: i64) -> Result<PolicyRule, MciasError> {
self.get(&format!("/v1/policy/rules/{id}")).await
}
/// Update a policy rule. Omitted fields are left unchanged.
pub async fn update_policy_rule(
&self,
id: i64,
req: UpdatePolicyRuleRequest,
) -> Result<PolicyRule, MciasError> {
self.patch(&format!("/v1/policy/rules/{id}"), &req).await
}
/// Delete a policy rule permanently.
pub async fn delete_policy_rule(&self, id: i64) -> Result<(), MciasError> {
self.delete(&format!("/v1/policy/rules/{id}")).await
}
// ---- HTTP helpers ---- // ---- HTTP helpers ----
/// Build a request with the Authorization header set from the stored token. /// Build the Authorization header value from the stored token.
/// Security: the token is read under a read-lock and is not logged. /// Security: the token is read under a read-lock and is never logged.
async fn auth_header(&self) -> Option<header::HeaderValue> { async fn auth_header(&self) -> Option<header::HeaderValue> {
let guard = self.token.read().await; let guard = self.token.read().await;
guard.as_deref().and_then(|tok| { guard.as_deref().and_then(|tok| {
@@ -383,6 +641,22 @@ impl Client {
self.expect_success(resp).await self.expect_success(resp).await
} }
async fn get_with_query<T: for<'de> Deserialize<'de>>(
&self,
path: &str,
params: &[(&str, String)],
) -> Result<T, MciasError> {
let mut req = self
.http
.get(format!("{}{path}", self.base_url))
.query(params);
if let Some(auth) = self.auth_header().await {
req = req.header(header::AUTHORIZATION, auth);
}
let resp = req.send().await?;
self.decode(resp).await
}
async fn post<B: Serialize, T: for<'de> Deserialize<'de>>( async fn post<B: Serialize, T: for<'de> Deserialize<'de>>(
&self, &self,
path: &str, path: &str,
@@ -434,6 +708,19 @@ impl Client {
self.expect_success(resp).await self.expect_success(resp).await
} }
/// POST with a JSON body that expects a 2xx (no body) response.
async fn post_empty_body<B: Serialize>(&self, path: &str, body: &B) -> Result<(), MciasError> {
let mut req = self
.http
.post(format!("{}{path}", self.base_url))
.json(body);
if let Some(auth) = self.auth_header().await {
req = req.header(header::AUTHORIZATION, auth);
}
let resp = req.send().await?;
self.expect_success(resp).await
}
async fn patch<B: Serialize, T: for<'de> Deserialize<'de>>( async fn patch<B: Serialize, T: for<'de> Deserialize<'de>>(
&self, &self,
path: &str, path: &str,
@@ -450,6 +737,18 @@ impl Client {
self.decode(resp).await self.decode(resp).await
} }
async fn patch_no_content<B: Serialize>(&self, path: &str, body: &B) -> Result<(), MciasError> {
let mut req = self
.http
.patch(format!("{}{path}", self.base_url))
.json(body);
if let Some(auth) = self.auth_header().await {
req = req.header(header::AUTHORIZATION, auth);
}
let resp = req.send().await?;
self.expect_success(resp).await
}
async fn put_no_content<B: Serialize + ?Sized>(&self, path: &str, body: &B) -> Result<(), MciasError> { async fn put_no_content<B: Serialize + ?Sized>(&self, path: &str, body: &B) -> Result<(), MciasError> {
let mut req = self let mut req = self
.http .http
@@ -462,6 +761,22 @@ impl Client {
self.expect_success(resp).await self.expect_success(resp).await
} }
async fn put_with_response<B: Serialize, T: for<'de> Deserialize<'de>>(
&self,
path: &str,
body: &B,
) -> Result<T, MciasError> {
let mut req = self
.http
.put(format!("{}{path}", self.base_url))
.json(body);
if let Some(auth) = self.auth_header().await {
req = req.header(header::AUTHORIZATION, auth);
}
let resp = req.send().await?;
self.decode(resp).await
}
async fn delete(&self, path: &str) -> Result<(), MciasError> { async fn delete(&self, path: &str) -> Result<(), MciasError> {
let mut req = self.http.delete(format!("{}{path}", self.base_url)); let mut req = self.http.delete(format!("{}{path}", self.base_url));
if let Some(auth) = self.auth_header().await { if let Some(auth) = self.auth_header().await {
@@ -471,6 +786,19 @@ impl Client {
self.expect_success(resp).await self.expect_success(resp).await
} }
/// DELETE with a JSON request body (used by `DELETE /v1/auth/totp`).
async fn delete_with_body<B: Serialize>(&self, path: &str, body: &B) -> Result<(), MciasError> {
let mut req = self
.http
.delete(format!("{}{path}", self.base_url))
.json(body);
if let Some(auth) = self.auth_header().await {
req = req.header(header::AUTHORIZATION, auth);
}
let resp = req.send().await?;
self.expect_success(resp).await
}
async fn decode<T: for<'de> Deserialize<'de>>( async fn decode<T: for<'de> Deserialize<'de>>(
&self, &self,
resp: reqwest::Response, resp: reqwest::Response,

View File

@@ -1,12 +1,18 @@
use mcias_client::{Client, ClientOptions, MciasError}; use mcias_client::{
AuditQuery, Client, ClientOptions, CreatePolicyRuleRequest, MciasError, RuleBody,
UpdatePolicyRuleRequest,
};
use wiremock::matchers::{method, path}; use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate}; use wiremock::{Mock, MockServer, ResponseTemplate};
async fn admin_client(server: &MockServer) -> Client { async fn admin_client(server: &MockServer) -> Client {
Client::new(&server.uri(), ClientOptions { Client::new(
&server.uri(),
ClientOptions {
token: Some("admin-token".to_string()), token: Some("admin-token".to_string()),
..Default::default() ..Default::default()
}) },
)
.unwrap() .unwrap()
} }
@@ -48,7 +54,10 @@ async fn test_health_server_error() {
let c = Client::new(&server.uri(), ClientOptions::default()).unwrap(); let c = Client::new(&server.uri(), ClientOptions::default()).unwrap();
let err = c.health().await.unwrap_err(); let err = c.health().await.unwrap_err();
assert!(matches!(err, MciasError::Server { .. }), "expected Server error, got {err:?}"); assert!(
matches!(err, MciasError::Server { .. }),
"expected Server error, got {err:?}"
);
} }
// ---- public key ---- // ---- public key ----
@@ -61,6 +70,8 @@ async fn test_get_public_key() {
.respond_with(json_body(serde_json::json!({ .respond_with(json_body(serde_json::json!({
"kty": "OKP", "kty": "OKP",
"crv": "Ed25519", "crv": "Ed25519",
"use": "sig",
"alg": "EdDSA",
"x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo" "x": "11qYAYKxCrfVS_7TyWQHOg7hcvPapiMlrwIaaPcHURo"
}))) })))
.mount(&server) .mount(&server)
@@ -70,6 +81,8 @@ async fn test_get_public_key() {
let pk = c.get_public_key().await.expect("get_public_key should succeed"); let pk = c.get_public_key().await.expect("get_public_key should succeed");
assert_eq!(pk.kty, "OKP"); assert_eq!(pk.kty, "OKP");
assert_eq!(pk.crv, "Ed25519"); assert_eq!(pk.crv, "Ed25519");
assert_eq!(pk.key_use.as_deref(), Some("sig"));
assert_eq!(pk.alg.as_deref(), Some("EdDSA"));
} }
// ---- login ---- // ---- login ----
@@ -99,7 +112,10 @@ async fn test_login_bad_credentials() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/v1/auth/login")) .and(path("/v1/auth/login"))
.respond_with(json_body_status(401, serde_json::json!({"error": "invalid credentials"}))) .respond_with(json_body_status(
401,
serde_json::json!({"error": "invalid credentials"}),
))
.mount(&server) .mount(&server)
.await; .await;
@@ -119,10 +135,13 @@ async fn test_logout_clears_token() {
.mount(&server) .mount(&server)
.await; .await;
let c = Client::new(&server.uri(), ClientOptions { let c = Client::new(
&server.uri(),
ClientOptions {
token: Some("existing-token".to_string()), token: Some("existing-token".to_string()),
..Default::default() ..Default::default()
}) },
)
.unwrap(); .unwrap();
c.logout().await.unwrap(); c.logout().await.unwrap();
assert!(c.token().await.is_none(), "token should be cleared after logout"); assert!(c.token().await.is_none(), "token should be cleared after logout");
@@ -142,10 +161,13 @@ async fn test_renew_token() {
.mount(&server) .mount(&server)
.await; .await;
let c = Client::new(&server.uri(), ClientOptions { let c = Client::new(
&server.uri(),
ClientOptions {
token: Some("old-token".to_string()), token: Some("old-token".to_string()),
..Default::default() ..Default::default()
}) },
)
.unwrap(); .unwrap();
let (tok, _) = c.renew_token().await.unwrap(); let (tok, _) = c.renew_token().await.unwrap();
assert_eq!(tok, "new-token"); assert_eq!(tok, "new-token");
@@ -224,7 +246,10 @@ async fn test_create_account_conflict() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("POST")) Mock::given(method("POST"))
.and(path("/v1/accounts")) .and(path("/v1/accounts"))
.respond_with(json_body_status(409, serde_json::json!({"error": "username already exists"}))) .respond_with(json_body_status(
409,
serde_json::json!({"error": "username already exists"}),
))
.mount(&server) .mount(&server)
.await; .await;
@@ -259,7 +284,10 @@ async fn test_get_account_not_found() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/v1/accounts/missing")) .and(path("/v1/accounts/missing"))
.respond_with(json_body_status(404, serde_json::json!({"error": "account not found"}))) .respond_with(json_body_status(
404,
serde_json::json!({"error": "account not found"}),
))
.mount(&server) .mount(&server)
.await; .await;
@@ -271,19 +299,15 @@ async fn test_get_account_not_found() {
#[tokio::test] #[tokio::test]
async fn test_update_account() { async fn test_update_account() {
let server = MockServer::start().await; let server = MockServer::start().await;
// PATCH /v1/accounts/{id} returns 204 No Content per spec.
Mock::given(method("PATCH")) Mock::given(method("PATCH"))
.and(path("/v1/accounts/uuid-1")) .and(path("/v1/accounts/uuid-1"))
.respond_with(json_body(serde_json::json!({ .respond_with(ResponseTemplate::new(204))
"id": "uuid-1", "username": "alice", "account_type": "human",
"status": "inactive", "created_at": "2023-11-15T12:00:00Z",
"updated_at": "2023-11-15T13:00:00Z", "totp_enabled": false
})))
.mount(&server) .mount(&server)
.await; .await;
let c = admin_client(&server).await; let c = admin_client(&server).await;
let a = c.update_account("uuid-1", "inactive").await.unwrap(); c.update_account("uuid-1", "inactive").await.unwrap();
assert_eq!(a.status, "inactive");
} }
#[tokio::test] #[tokio::test]
@@ -305,12 +329,14 @@ async fn test_delete_account() {
async fn test_get_set_roles() { async fn test_get_set_roles() {
let server = MockServer::start().await; let server = MockServer::start().await;
// Spec wraps the array: {"roles": [...]}
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/v1/accounts/uuid-1/roles")) .and(path("/v1/accounts/uuid-1/roles"))
.respond_with(json_body(serde_json::json!(["admin", "viewer"]))) .respond_with(json_body(serde_json::json!({"roles": ["admin", "viewer"]})))
.mount(&server) .mount(&server)
.await; .await;
// Spec requires {"roles": [...]} in the PUT body.
Mock::given(method("PUT")) Mock::given(method("PUT"))
.and(path("/v1/accounts/uuid-1/roles")) .and(path("/v1/accounts/uuid-1/roles"))
.respond_with(ResponseTemplate::new(204)) .respond_with(ResponseTemplate::new(204))
@@ -363,7 +389,10 @@ async fn test_pg_creds_not_found() {
let server = MockServer::start().await; let server = MockServer::start().await;
Mock::given(method("GET")) Mock::given(method("GET"))
.and(path("/v1/accounts/uuid-1/pgcreds")) .and(path("/v1/accounts/uuid-1/pgcreds"))
.respond_with(json_body_status(404, serde_json::json!({"error": "no pg credentials found"}))) .respond_with(json_body_status(
404,
serde_json::json!({"error": "no pg credentials found"}),
))
.mount(&server) .mount(&server)
.await; .await;
@@ -405,6 +434,298 @@ async fn test_set_get_pg_creds() {
assert_eq!(creds.password, "dbpass"); assert_eq!(creds.password, "dbpass");
} }
// ---- TOTP ----
#[tokio::test]
async fn test_enroll_totp() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/auth/totp/enroll"))
.respond_with(json_body(serde_json::json!({
"secret": "JBSWY3DPEHPK3PXP",
"otpauth_uri": "otpauth://totp/MCIAS:alice?secret=JBSWY3DPEHPK3PXP&issuer=MCIAS"
})))
.mount(&server)
.await;
let c = admin_client(&server).await;
let (secret, uri) = c.enroll_totp().await.unwrap();
assert_eq!(secret, "JBSWY3DPEHPK3PXP");
assert!(uri.starts_with("otpauth://totp/"));
}
#[tokio::test]
async fn test_confirm_totp() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/auth/totp/confirm"))
.respond_with(ResponseTemplate::new(204))
.mount(&server)
.await;
let c = admin_client(&server).await;
c.confirm_totp("123456").await.unwrap();
}
#[tokio::test]
async fn test_remove_totp() {
let server = MockServer::start().await;
Mock::given(method("DELETE"))
.and(path("/v1/auth/totp"))
.respond_with(ResponseTemplate::new(204))
.mount(&server)
.await;
let c = admin_client(&server).await;
c.remove_totp("some-account-uuid").await.unwrap();
}
// ---- password management ----
#[tokio::test]
async fn test_change_password() {
let server = MockServer::start().await;
Mock::given(method("PUT"))
.and(path("/v1/auth/password"))
.respond_with(ResponseTemplate::new(204))
.mount(&server)
.await;
let c = admin_client(&server).await;
c.change_password("old-pass", "new-pass-long-enough").await.unwrap();
}
#[tokio::test]
async fn test_change_password_wrong_current() {
let server = MockServer::start().await;
Mock::given(method("PUT"))
.and(path("/v1/auth/password"))
.respond_with(json_body_status(
401,
serde_json::json!({"error": "current password is incorrect", "code": "unauthorized"}),
))
.mount(&server)
.await;
let c = admin_client(&server).await;
let err = c
.change_password("wrong", "new-pass-long-enough")
.await
.unwrap_err();
assert!(matches!(err, MciasError::Auth(_)));
}
#[tokio::test]
async fn test_admin_set_password() {
let server = MockServer::start().await;
Mock::given(method("PUT"))
.and(path("/v1/accounts/uuid-1/password"))
.respond_with(ResponseTemplate::new(204))
.mount(&server)
.await;
let c = admin_client(&server).await;
c.admin_set_password("uuid-1", "new-pass-long-enough").await.unwrap();
}
// ---- tags ----
#[tokio::test]
async fn test_get_set_tags() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/accounts/uuid-1/tags"))
.respond_with(json_body(
serde_json::json!({"tags": ["env:production", "svc:payments-api"]}),
))
.mount(&server)
.await;
Mock::given(method("PUT"))
.and(path("/v1/accounts/uuid-1/tags"))
.respond_with(json_body(serde_json::json!({"tags": ["env:staging"]})))
.mount(&server)
.await;
let c = admin_client(&server).await;
let tags = c.get_tags("uuid-1").await.unwrap();
assert_eq!(tags, vec!["env:production", "svc:payments-api"]);
let updated = c.set_tags("uuid-1", &["env:staging"]).await.unwrap();
assert_eq!(updated, vec!["env:staging"]);
}
// ---- audit log ----
#[tokio::test]
async fn test_list_audit() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/audit"))
.respond_with(json_body(serde_json::json!({
"events": [
{
"id": 1,
"event_type": "login_ok",
"event_time": "2026-03-11T09:01:23Z",
"ip_address": "192.0.2.1",
"actor_id": "uuid-1",
"target_id": null,
"details": null
}
],
"total": 1,
"limit": 50,
"offset": 0
})))
.mount(&server)
.await;
let c = admin_client(&server).await;
let page = c.list_audit(AuditQuery::default()).await.unwrap();
assert_eq!(page.total, 1);
assert_eq!(page.events.len(), 1);
assert_eq!(page.events[0].event_type, "login_ok");
}
// ---- policy rules ----
#[tokio::test]
async fn test_list_policy_rules() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/policy/rules"))
.respond_with(json_body(serde_json::json!([])))
.mount(&server)
.await;
let c = admin_client(&server).await;
let rules = c.list_policy_rules().await.unwrap();
assert!(rules.is_empty());
}
#[tokio::test]
async fn test_create_policy_rule() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/v1/policy/rules"))
.respond_with(
ResponseTemplate::new(201)
.set_body_json(serde_json::json!({
"id": 1,
"priority": 100,
"description": "Allow payments-api to read its own pgcreds",
"rule": {"effect": "allow", "roles": ["svc:payments-api"]},
"enabled": true,
"not_before": null,
"expires_at": null,
"created_at": "2026-03-11T09:00:00Z",
"updated_at": "2026-03-11T09:00:00Z"
}))
.insert_header("content-type", "application/json"),
)
.mount(&server)
.await;
let c = admin_client(&server).await;
let rule = c
.create_policy_rule(CreatePolicyRuleRequest {
description: "Allow payments-api to read its own pgcreds".to_string(),
rule: RuleBody {
effect: "allow".to_string(),
roles: Some(vec!["svc:payments-api".to_string()]),
account_types: None,
subject_uuid: None,
actions: None,
resource_type: None,
owner_matches_subject: None,
service_names: None,
required_tags: None,
},
priority: None,
not_before: None,
expires_at: None,
})
.await
.unwrap();
assert_eq!(rule.id, 1);
assert_eq!(rule.description, "Allow payments-api to read its own pgcreds");
}
#[tokio::test]
async fn test_get_policy_rule() {
let server = MockServer::start().await;
Mock::given(method("GET"))
.and(path("/v1/policy/rules/1"))
.respond_with(json_body(serde_json::json!({
"id": 1,
"priority": 100,
"description": "test rule",
"rule": {"effect": "deny"},
"enabled": true,
"not_before": null,
"expires_at": null,
"created_at": "2026-03-11T09:00:00Z",
"updated_at": "2026-03-11T09:00:00Z"
})))
.mount(&server)
.await;
let c = admin_client(&server).await;
let rule = c.get_policy_rule(1).await.unwrap();
assert_eq!(rule.id, 1);
assert_eq!(rule.rule.effect, "deny");
}
#[tokio::test]
async fn test_update_policy_rule() {
let server = MockServer::start().await;
Mock::given(method("PATCH"))
.and(path("/v1/policy/rules/1"))
.respond_with(json_body(serde_json::json!({
"id": 1,
"priority": 75,
"description": "updated rule",
"rule": {"effect": "allow"},
"enabled": false,
"not_before": null,
"expires_at": null,
"created_at": "2026-03-11T09:00:00Z",
"updated_at": "2026-03-11T10:00:00Z"
})))
.mount(&server)
.await;
let c = admin_client(&server).await;
let rule = c
.update_policy_rule(
1,
UpdatePolicyRuleRequest {
enabled: Some(false),
priority: Some(75),
..Default::default()
},
)
.await
.unwrap();
assert!(!rule.enabled);
assert_eq!(rule.priority, 75);
}
#[tokio::test]
async fn test_delete_policy_rule() {
let server = MockServer::start().await;
Mock::given(method("DELETE"))
.and(path("/v1/policy/rules/1"))
.respond_with(ResponseTemplate::new(204))
.mount(&server)
.await;
let c = admin_client(&server).await;
c.delete_policy_rule(1).await.unwrap();
}
// ---- error type coverage ---- // ---- error type coverage ----
#[tokio::test] #[tokio::test]
@@ -416,11 +737,13 @@ async fn test_forbidden_error() {
.mount(&server) .mount(&server)
.await; .await;
// Use a non-admin token. let c = Client::new(
let c = Client::new(&server.uri(), ClientOptions { &server.uri(),
ClientOptions {
token: Some("user-token".to_string()), token: Some("user-token".to_string()),
..Default::default() ..Default::default()
}) },
)
.unwrap(); .unwrap();
let err = c.list_accounts().await.unwrap_err(); let err = c.list_accounts().await.unwrap_err();
assert!(matches!(err, MciasError::Forbidden(_))); assert!(matches!(err, MciasError::Forbidden(_)));