Files
mcias/test/mock/mockserver.go
Kyle Isom 4596ea08ab Fix grpcserver rate limiter: move to Server field
The package-level defaultRateLimiter drained its token bucket
across all test cases, causing later tests to hit ResourceExhausted.
Move rateLimiter from a package-level var to a *grpcRateLimiter field
on Server; New() allocates a fresh instance (10 req/s, burst 10) per
server. Each test's newTestEnv() constructs its own Server, so tests
no longer share limiter state.

Production behaviour is unchanged: a single Server is constructed at
startup and lives for the process lifetime.
2026-03-11 19:23:34 -07:00

527 lines
14 KiB
Go

// Package mock provides an in-memory MCIAS server for integration tests.
//
// Server serves the MCIAS /v1 HTTP endpoints from an httptest.Server and
// keeps accounts, bearer tokens, and per-account Postgres credentials in
// plain maps protected by a single sync.RWMutex.
//
// Security note: this package is test-only. It never enforces TLS and uses
// trivial token generation. Do not use in production.
package mock
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
)
// Account is the mock server's record of a single MCIAS account.
//
// ID is the server-assigned identifier ("mock-uuid-N"). Username and
// Password are the credentials checked by /v1/auth/login. AccountType and
// Status are echoed by the account endpoints; Status must be "active" for
// login to succeed. Roles lists granted roles; "admin" unlocks the
// administrative endpoints.
type Account struct {
	ID, Username, Password, AccountType, Status string

	Roles []string
}
// PGCreds is the Postgres connection information the mock server stores for
// an account via PUT /v1/accounts/{id}/pgcreds and returns from the matching
// GET.
type PGCreds struct {
	Host, Database, Username, Password string

	Port int
}
// Server is an in-memory MCIAS mock server.
//
// Construct it with NewServer; the zero value has nil maps and no listener.
type Server struct {
	httpServer *httptest.Server // assigned once by NewServer

	// mu guards all of the fields that follow.
	mu       sync.RWMutex
	accounts map[string]*Account // account ID → account
	byName   map[string]*Account // username → account
	tokens   map[string]string   // live token → account ID
	revoked  map[string]bool     // tokens invalidated by logout/renew
	pgcreds  map[string]*PGCreds // account ID → Postgres credentials
	nextSeq  int                 // shared counter for generated IDs and tokens
}
// NewServer creates a mock server with empty state, registers every /v1
// route on a fresh ServeMux, and starts it on an httptest.Server.
// Call Close() when done.
func NewServer() *Server {
	srv := &Server{
		accounts: map[string]*Account{},
		byName:   map[string]*Account{},
		tokens:   map[string]string{},
		revoked:  map[string]bool{},
		pgcreds:  map[string]*PGCreds{},
	}

	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/v1/health", srv.handleHealth},
		{"/v1/keys/public", srv.handlePublicKey},
		{"/v1/auth/login", srv.handleLogin},
		{"/v1/auth/logout", srv.handleLogout},
		{"/v1/auth/renew", srv.handleRenew},
		{"/v1/token/validate", srv.handleValidate},
		{"/v1/token/issue", srv.handleIssueToken},
		{"/v1/accounts", srv.handleAccounts},
		// Trailing slash: matches every /v1/accounts/{id}[/...] path.
		{"/v1/accounts/", srv.handleAccountByID},
	}

	mux := http.NewServeMux()
	for _, rt := range routes {
		mux.HandleFunc(rt.pattern, rt.handler)
	}

	srv.httpServer = httptest.NewServer(mux)
	return srv
}
// URL reports the base URL of the running httptest server; clients append
// /v1/... paths to it.
func (s *Server) URL() string {
	base := s.httpServer.URL
	return base
}
// Close shuts down the underlying httptest server. Call it once the test is
// finished with the mock (typically via defer or t.Cleanup).
func (s *Server) Close() {
	srv := s.httpServer
	srv.Close()
}
// AddAccount registers an active account with the given credentials, type,
// and roles, and returns its generated "mock-uuid-N" ID. The roles are
// copied, so the caller may reuse its slice. If the username is already
// taken, the username index is pointed at the new account while the older
// account remains reachable by its ID.
func (s *Server) AddAccount(username, password, accountType string, roles ...string) string {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.nextSeq++
	newID := fmt.Sprintf("mock-uuid-%d", s.nextSeq)
	rolesCopy := append([]string{}, roles...)

	entry := &Account{
		ID:          newID,
		Username:    username,
		Password:    password,
		AccountType: accountType,
		Status:      "active",
		Roles:       rolesCopy,
	}
	s.byName[username] = entry
	s.accounts[newID] = entry
	return newID
}
// IssueToken registers token as a valid bearer token for accountID, for
// pre-authenticated test setup. The account is not required to exist.
// Any earlier revocation of the same token string (by logout or renew) is
// cleared, so the token is honoured immediately.
func (s *Server) IssueToken(accountID, token string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Every handler checks s.revoked before s.tokens; without this delete a
	// re-issued token string that was previously revoked would be rejected.
	delete(s.revoked, token)
	s.tokens[token] = accountID
}
// issueToken mints a "mock-token-N" bearer token bound to accountID and
// returns it. The sequence counter is shared with generated account IDs, so
// generated tokens never repeat.
// The caller must already hold s.mu for writing.
func (s *Server) issueToken(accountID string) string {
	s.nextSeq++
	minted := fmt.Sprintf("mock-token-%d", s.nextSeq)
	s.tokens[minted] = accountID
	return minted
}
// bearerToken extracts the credential from an "Authorization: Bearer <tok>"
// header. The scheme is matched case-insensitively. It returns "" when the
// header is missing, uses another scheme, or carries an empty token.
func (s *Server) bearerToken(r *http.Request) string {
	const prefix = "bearer "
	header := r.Header.Get("Authorization")
	if len(header) <= len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return ""
	}
	return header[len(prefix):]
}
// authenticatedAccount resolves the request's bearer token to an account.
// It returns nil when there is no token, when the token is revoked or
// unknown, or when the account it maps to no longer exists.
// The returned pointer is shared server state: reading its mutable fields
// (Status, Roles) without s.mu races with handlers that update them.
func (s *Server) authenticatedAccount(r *http.Request) *Account {
	tok := s.bearerToken(r)
	if tok == "" {
		return nil
	}

	s.mu.RLock()
	defer s.mu.RUnlock()
	owner, live := s.tokens[tok]
	if s.revoked[tok] || !live {
		return nil
	}
	return s.accounts[owner]
}
// sendJSON writes v as a JSON response body with the given status code and
// an "application/json" Content-Type.
func sendJSON(w http.ResponseWriter, status int, v interface{}) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	// Best effort: the status line is already sent, so an encoding or write
	// failure has no channel left to be reported on.
	_ = enc.Encode(v)
}
// sendError writes a JSON error body of the form {"error": msg} with the
// given HTTP status code.
func sendError(w http.ResponseWriter, status int, msg string) {
	payload := map[string]string{"error": msg}
	sendJSON(w, status, payload)
}
// accountToMap renders a as the JSON object returned by the account
// endpoints. created_at, updated_at, and totp_enabled are fixed placeholder
// values; Password and Roles are not included. The caller must make sure
// a's fields are not being modified concurrently (e.g. by holding s.mu).
func (s *Server) accountToMap(a *Account) map[string]interface{} {
	const placeholderTime = "2023-11-15T12:00:00Z"
	out := make(map[string]interface{}, 7)
	out["id"] = a.ID
	out["username"] = a.Username
	out["account_type"] = a.AccountType
	out["status"] = a.Status
	out["created_at"] = placeholderTime
	out["updated_at"] = placeholderTime
	out["totp_enabled"] = false
	return out
}
// handleHealth serves GET /v1/health with {"status":"ok"}; any other method
// receives 405.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		sendJSON(w, http.StatusOK, map[string]string{"status": "ok"})
	default:
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}
// handlePublicKey serves GET /v1/keys/public with a fixed placeholder
// Ed25519 JWK (kty OKP, alg EdDSA); any other method receives 405.
func (s *Server) handlePublicKey(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		jwk := map[string]string{
			"kty": "OKP",
			"crv": "Ed25519",
			"x":   "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
			"use": "sig",
			"alg": "EdDSA",
		}
		sendJSON(w, http.StatusOK, jwk)
		return
	}
	sendError(w, http.StatusMethodNotAllowed, "method not allowed")
}
// handleLogin serves POST /v1/auth/login. It takes {"username", "password"}.
// On success it returns 200 with a fresh token and a far-future expires_at.
// A malformed body gets 400. An unknown user, wrong password, or non-active
// account gets 401 "invalid credentials".
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	var creds struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}
	if err := json.NewDecoder(r.Body).Decode(&creds); err != nil {
		sendError(w, http.StatusBadRequest, "bad request")
		return
	}

	// Check credentials and mint the token in one critical section, then
	// release the lock before writing to the client.
	s.mu.Lock()
	acct, found := s.byName[creds.Username]
	authenticated := found && acct.Password == creds.Password && acct.Status == "active"
	var token string
	if authenticated {
		token = s.issueToken(acct.ID)
	}
	s.mu.Unlock()

	if !authenticated {
		sendError(w, http.StatusUnauthorized, "invalid credentials")
		return
	}
	sendJSON(w, http.StatusOK, map[string]string{
		"token":      token,
		"expires_at": "2099-01-01T00:00:00Z",
	})
}
// handleLogout serves POST /v1/auth/logout. It revokes the caller's bearer
// token and answers 204. The token is not checked for validity first. A
// request without a bearer token gets 401.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	tok := s.bearerToken(r)
	if tok == "" {
		sendError(w, http.StatusUnauthorized, "unauthorized")
		return
	}

	s.mu.Lock()
	delete(s.tokens, tok)
	s.revoked[tok] = true
	s.mu.Unlock()

	w.WriteHeader(http.StatusNoContent)
}
// handleRenew serves POST /v1/auth/renew. It swaps a live bearer token for a
// freshly minted one bound to the same account, revoking the old token. A
// missing, revoked, or unknown token gets 401.
func (s *Server) handleRenew(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	oldTok := s.bearerToken(r)
	if oldTok == "" {
		sendError(w, http.StatusUnauthorized, "unauthorized")
		return
	}

	s.mu.Lock()
	owner, known := s.tokens[oldTok]
	renewable := known && !s.revoked[oldTok]
	var fresh string
	if renewable {
		s.revoked[oldTok] = true
		delete(s.tokens, oldTok)
		fresh = s.issueToken(owner)
	}
	s.mu.Unlock()

	if !renewable {
		sendError(w, http.StatusUnauthorized, "unauthorized")
		return
	}
	sendJSON(w, http.StatusOK, map[string]string{
		"token":      fresh,
		"expires_at": "2099-01-01T00:00:00Z",
	})
}
// handleValidate serves POST /v1/token/validate.
//
// The token is taken from the JSON body's "token" field. If the body is
// absent, malformed, or omits the field, the Authorization bearer token is
// used instead. The response is always 200. It is {"valid": false} for a
// missing, revoked, or unknown token, and also for a token whose account has
// since been deleted. Otherwise it is {"valid": true, "sub", "roles",
// "expires_at"}, with "roles" always encoded as a JSON array.
func (s *Server) handleValidate(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	var req struct {
		Token string `json:"token"`
	}
	// The body is optional: a decode failure just leaves req.Token empty so
	// the Authorization header is consulted instead.
	_ = json.NewDecoder(r.Body).Decode(&req)
	tok := req.Token
	if tok == "" {
		tok = s.bearerToken(r)
	}

	var (
		valid bool
		sub   string
		roles []string
	)
	s.mu.RLock()
	if tok != "" && !s.revoked[tok] {
		if aid, ok := s.tokens[tok]; ok {
			// A deleted account can leave its tokens behind; s.accounts then
			// yields nil, which previously caused a nil-pointer panic.
			if acct := s.accounts[aid]; acct != nil {
				valid = true
				sub = acct.ID
				// Copy while the lock is held so the response can be written
				// after releasing it; the copy is never nil, so it encodes as [].
				roles = append([]string{}, acct.Roles...)
			}
		}
	}
	s.mu.RUnlock()

	if !valid {
		sendJSON(w, http.StatusOK, map[string]interface{}{"valid": false})
		return
	}
	sendJSON(w, http.StatusOK, map[string]interface{}{
		"valid":      true,
		"sub":        sub,
		"roles":      roles,
		"expires_at": "2099-01-01T00:00:00Z",
	})
}
// handleIssueToken serves POST /v1/token/issue. An authenticated admin mints
// a new token for the account named by {"account_id": ...}.
//
// Responses:
//   - 405 for a non-POST method.
//   - 401 or 403 from requireAdmin.
//   - 400 for a malformed body or an empty account_id.
//   - 404 for an unknown account.
//   - 200 with {"token", "expires_at"} on success.
func (s *Server) handleIssueToken(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Use the shared admin gate rather than an inline copy of the role scan.
	// The inline version read acct.Roles without holding s.mu.
	if s.requireAdmin(w, r) == nil {
		return
	}
	var req struct {
		AccountID string `json:"account_id"`
	}
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.AccountID == "" {
		sendError(w, http.StatusBadRequest, "bad request")
		return
	}

	s.mu.Lock()
	_, exists := s.accounts[req.AccountID]
	var tok string
	if exists {
		tok = s.issueToken(req.AccountID)
	}
	s.mu.Unlock()

	if !exists {
		sendError(w, http.StatusNotFound, "account not found")
		return
	}
	sendJSON(w, http.StatusOK, map[string]string{
		"token":      tok,
		"expires_at": "2099-01-01T00:00:00Z",
	})
}
// handleAccounts serves /v1/accounts for admins.
//
// GET lists every account (in map order) as JSON objects. POST creates an
// active account from {"username", "account_type", "password"} with no
// roles. POST returns 201 with the new account, 400 for a malformed body or
// empty username, and 409 if the username is taken. Other methods get 405.
func (s *Server) handleAccounts(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		if s.requireAdmin(w, r) == nil {
			return
		}
		s.mu.RLock()
		list := make([]map[string]interface{}, 0, len(s.accounts))
		for _, a := range s.accounts {
			list = append(list, s.accountToMap(a))
		}
		s.mu.RUnlock()
		sendJSON(w, http.StatusOK, list)
	case http.MethodPost:
		if s.requireAdmin(w, r) == nil {
			return
		}
		var req struct {
			Username    string `json:"username"`
			AccountType string `json:"account_type"`
			Password    string `json:"password"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == "" {
			sendError(w, http.StatusBadRequest, "bad request")
			return
		}
		s.mu.Lock()
		if _, exists := s.byName[req.Username]; exists {
			s.mu.Unlock()
			sendError(w, http.StatusConflict, "username already exists")
			return
		}
		s.nextSeq++
		id := fmt.Sprintf("mock-uuid-%d", s.nextSeq)
		newAcct := &Account{
			ID:          id,
			Username:    req.Username,
			Password:    req.Password,
			AccountType: req.AccountType,
			Status:      "active",
			Roles:       []string{},
		}
		s.accounts[id] = newAcct
		s.byName[req.Username] = newAcct
		// Render while still holding the lock. Once it is released, the
		// account is reachable by other handlers, and a concurrent PATCH of
		// Status would race with this read.
		body := s.accountToMap(newAcct)
		s.mu.Unlock()
		sendJSON(w, http.StatusCreated, body)
	default:
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}
// handleAccountByID routes /v1/accounts/{id}[/roles|/pgcreds] to the
// per-account handlers. Everything after the first "/" following the ID is
// the sub-resource. An empty sub-resource (including a trailing slash) goes
// to handleSingleAccount; any unrecognised one gets 404.
func (s *Server) handleAccountByID(w http.ResponseWriter, r *http.Request) {
	rest := strings.TrimPrefix(r.URL.Path, "/v1/accounts/")
	id, sub := rest, ""
	if slash := strings.IndexByte(rest, '/'); slash >= 0 {
		id, sub = rest[:slash], rest[slash+1:]
	}

	switch sub {
	case "":
		s.handleSingleAccount(w, r, id)
	case "roles":
		s.handleRoles(w, r, id)
	case "pgcreds":
		s.handlePGCreds(w, r, id)
	default:
		sendError(w, http.StatusNotFound, "not found")
	}
}
// requireAdmin authenticates the request and checks that the caller holds
// the "admin" role. On success it returns the caller's account. Otherwise
// it writes the error response and returns nil, so callers should simply
// return on nil. A missing or invalid token gets 401 "unauthorized"; a
// caller without the admin role gets 403 "forbidden".
func (s *Server) requireAdmin(w http.ResponseWriter, r *http.Request) *Account {
	acct := s.authenticatedAccount(r)
	if acct == nil {
		sendError(w, http.StatusUnauthorized, "unauthorized")
		return nil
	}
	// handleRoles replaces acct.Roles while holding s.mu, so the role list
	// must be read under the lock too to avoid a data race.
	s.mu.RLock()
	isAdmin := false
	for _, role := range acct.Roles {
		if role == "admin" {
			isAdmin = true
			break
		}
	}
	s.mu.RUnlock()
	if !isAdmin {
		sendError(w, http.StatusForbidden, "forbidden")
		return nil
	}
	return acct
}
// handleSingleAccount serves /v1/accounts/{id} for admins.
//
// GET returns the account as JSON. PATCH updates the status from
// {"status": ...}; an empty status leaves it unchanged. DELETE removes the
// account together with its live tokens and stored Postgres credentials and
// answers 204. An unknown ID gets 404 and other methods get 405.
func (s *Server) handleSingleAccount(w http.ResponseWriter, r *http.Request, id string) {
	if s.requireAdmin(w, r) == nil {
		return
	}
	s.mu.RLock()
	acct, ok := s.accounts[id]
	s.mu.RUnlock()
	if !ok {
		sendError(w, http.StatusNotFound, "account not found")
		return
	}
	switch r.Method {
	case http.MethodGet:
		// Render under the lock: PATCH mutates acct.Status concurrently.
		s.mu.RLock()
		body := s.accountToMap(acct)
		s.mu.RUnlock()
		sendJSON(w, http.StatusOK, body)
	case http.MethodPatch:
		var req struct {
			Status string `json:"status"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			sendError(w, http.StatusBadRequest, "bad request")
			return
		}
		// Update and render in one critical section so the response reflects
		// this write and does not race with other writers.
		s.mu.Lock()
		if req.Status != "" {
			acct.Status = req.Status
		}
		body := s.accountToMap(acct)
		s.mu.Unlock()
		sendJSON(w, http.StatusOK, body)
	case http.MethodDelete:
		s.mu.Lock()
		delete(s.accounts, id)
		// Only drop the username index entry if it still refers to this
		// account. A newer account may have been registered under the same
		// username (AddAccount re-points the index), and it must stay
		// reachable.
		if s.byName[acct.Username] == acct {
			delete(s.byName, acct.Username)
		}
		// Remove everything still keyed by the deleted ID so no token or
		// credential keeps referring to it.
		delete(s.pgcreds, id)
		for tok, owner := range s.tokens {
			if owner == id {
				delete(s.tokens, tok)
			}
		}
		s.mu.Unlock()
		w.WriteHeader(http.StatusNoContent)
	default:
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}
// handleRoles serves /v1/accounts/{id}/roles for admins.
//
// GET returns {"roles": [...]}. PUT replaces the role list from
// {"roles": [...]} and echoes the stored list back. An unknown ID gets 404,
// a malformed PUT body gets 400, and other methods get 405.
func (s *Server) handleRoles(w http.ResponseWriter, r *http.Request, id string) {
	if s.requireAdmin(w, r) == nil {
		return
	}
	s.mu.RLock()
	acct, ok := s.accounts[id]
	s.mu.RUnlock()
	if !ok {
		sendError(w, http.StatusNotFound, "account not found")
		return
	}
	switch r.Method {
	case http.MethodGet:
		s.mu.RLock()
		roles := append([]string{}, acct.Roles...)
		s.mu.RUnlock()
		sendJSON(w, http.StatusOK, map[string]interface{}{"roles": roles})
	case http.MethodPut:
		var req struct {
			Roles []string `json:"roles"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			sendError(w, http.StatusBadRequest, "bad request")
			return
		}
		// Normalise a missing or null list to an empty, non-nil slice. This
		// matches how AddAccount and account creation store roles, so every
		// response encodes "roles" as [] rather than null.
		roles := append([]string{}, req.Roles...)
		s.mu.Lock()
		acct.Roles = roles
		s.mu.Unlock()
		// Safe to read after unlocking: the slice is replaced wholesale on
		// update and never modified in place.
		sendJSON(w, http.StatusOK, map[string]interface{}{"roles": roles})
	default:
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}
// handlePGCreds serves /v1/accounts/{id}/pgcreds for admins.
//
// GET returns the stored Postgres credentials, or 404 "no pg credentials"
// if none have been set. PUT replaces them from the JSON body and answers
// 204. An unknown account ID gets 404, a malformed PUT body gets 400, and
// other methods get 405.
func (s *Server) handlePGCreds(w http.ResponseWriter, r *http.Request, id string) {
	if s.requireAdmin(w, r) == nil {
		return
	}
	s.mu.RLock()
	_, known := s.accounts[id]
	s.mu.RUnlock()
	if !known {
		sendError(w, http.StatusNotFound, "account not found")
		return
	}

	switch r.Method {
	case http.MethodGet:
		s.mu.RLock()
		stored, found := s.pgcreds[id]
		s.mu.RUnlock()
		if !found {
			sendError(w, http.StatusNotFound, "no pg credentials")
			return
		}
		// Reading stored after unlocking is safe: PUT swaps in a new
		// *PGCreds rather than mutating the existing one.
		sendJSON(w, http.StatusOK, map[string]interface{}{
			"host":     stored.Host,
			"port":     stored.Port,
			"database": stored.Database,
			"username": stored.Username,
			"password": stored.Password,
		})
	case http.MethodPut:
		var body struct {
			Host     string `json:"host"`
			Database string `json:"database"`
			Username string `json:"username"`
			Password string `json:"password"`
			Port     int    `json:"port"`
		}
		if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
			sendError(w, http.StatusBadRequest, "bad request")
			return
		}
		replacement := &PGCreds{
			Host:     body.Host,
			Port:     body.Port,
			Database: body.Database,
			Username: body.Username,
			Password: body.Password,
		}
		s.mu.Lock()
		s.pgcreds[id] = replacement
		s.mu.Unlock()
		w.WriteHeader(http.StatusNoContent)
	default:
		sendError(w, http.StatusMethodNotAllowed, "method not allowed")
	}
}