The package-level defaultRateLimiter drained its token bucket across all test cases, causing later tests to hit ResourceExhausted. Move rateLimiter from a package-level var to a *grpcRateLimiter field on Server; New() allocates a fresh instance (10 req/s, burst 10) per server. Each test's newTestEnv() constructs its own Server, so tests no longer share limiter state. Production behaviour is unchanged: a single Server is constructed at startup and lives for the process lifetime.
527 lines
14 KiB
Go
527 lines
14 KiB
Go
// Package mock provides an in-memory MCIAS server for integration tests.
|
|
//
|
|
// Security note: this package is test-only. It never enforces TLS and uses
|
|
// trivial token generation. Do not use in production.
|
|
package mock
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// Account is the in-memory record the mock server keeps for a single account.
type Account struct {
	// ID is the server-assigned identifier; the mock generates "mock-uuid-N".
	ID string
	// Username is the login name and the key of the server's name index.
	Username string
	// Password is stored in plain text and compared verbatim at login.
	Password string
	// AccountType is an opaque label echoed back as "account_type".
	AccountType string
	// Status starts as "active"; login is refused for any other value.
	Status string
	// Roles lists the granted role names; "admin" unlocks the admin endpoints.
	Roles []string
}
|
|
|
|
// PGCreds is the Postgres connection record the mock stores per account.
// Values are kept exactly as supplied by the PUT .../pgcreds request.
type PGCreds struct {
	// Host is the database host name.
	Host string
	// Database is the database name.
	Database string
	// Username is the database login.
	Username string
	// Password is the database password, stored in plain text.
	Password string
	// Port is the database TCP port.
	Port int
}
|
|
|
|
// Server is an in-memory MCIAS mock server.
|
|
type Server struct {
|
|
httpServer *httptest.Server
|
|
accounts map[string]*Account // id → account
|
|
byName map[string]*Account // username → account
|
|
tokens map[string]string // token → account id
|
|
revoked map[string]bool // revoked tokens
|
|
pgcreds map[string]*PGCreds // account id → pg creds
|
|
nextSeq int
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewServer creates and starts a new mock server. Call Close() when done.
|
|
func NewServer() *Server {
|
|
s := &Server{
|
|
accounts: make(map[string]*Account),
|
|
byName: make(map[string]*Account),
|
|
tokens: make(map[string]string),
|
|
revoked: make(map[string]bool),
|
|
pgcreds: make(map[string]*PGCreds),
|
|
}
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc("/v1/health", s.handleHealth)
|
|
mux.HandleFunc("/v1/keys/public", s.handlePublicKey)
|
|
mux.HandleFunc("/v1/auth/login", s.handleLogin)
|
|
mux.HandleFunc("/v1/auth/logout", s.handleLogout)
|
|
mux.HandleFunc("/v1/auth/renew", s.handleRenew)
|
|
mux.HandleFunc("/v1/token/validate", s.handleValidate)
|
|
mux.HandleFunc("/v1/token/issue", s.handleIssueToken)
|
|
mux.HandleFunc("/v1/accounts", s.handleAccounts)
|
|
mux.HandleFunc("/v1/accounts/", s.handleAccountByID)
|
|
s.httpServer = httptest.NewServer(mux)
|
|
return s
|
|
}
|
|
|
|
// URL returns the base URL of the mock server.
|
|
func (s *Server) URL() string {
|
|
return s.httpServer.URL
|
|
}
|
|
|
|
// Close shuts down the mock server.
|
|
func (s *Server) Close() {
|
|
s.httpServer.Close()
|
|
}
|
|
|
|
// AddAccount adds a test account and returns its ID.
|
|
func (s *Server) AddAccount(username, password, accountType string, roles ...string) string {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.nextSeq++
|
|
id := fmt.Sprintf("mock-uuid-%d", s.nextSeq)
|
|
acct := &Account{
|
|
ID: id,
|
|
Username: username,
|
|
Password: password,
|
|
AccountType: accountType,
|
|
Status: "active",
|
|
Roles: append([]string{}, roles...),
|
|
}
|
|
s.accounts[id] = acct
|
|
s.byName[username] = acct
|
|
return id
|
|
}
|
|
|
|
// IssueToken directly adds a token for an account (for pre-auth test setup).
|
|
func (s *Server) IssueToken(accountID, token string) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.tokens[token] = accountID
|
|
}
|
|
|
|
// issueToken creates a new token for the given account ID.
|
|
// Caller must hold s.mu (write lock).
|
|
func (s *Server) issueToken(accountID string) string {
|
|
s.nextSeq++
|
|
tok := fmt.Sprintf("mock-token-%d", s.nextSeq)
|
|
s.tokens[tok] = accountID
|
|
return tok
|
|
}
|
|
func (s *Server) bearerToken(r *http.Request) string {
|
|
auth := r.Header.Get("Authorization")
|
|
if len(auth) > 7 && strings.ToLower(auth[:7]) == "bearer " {
|
|
return auth[7:]
|
|
}
|
|
return ""
|
|
}
|
|
func (s *Server) authenticatedAccount(r *http.Request) *Account {
|
|
tok := s.bearerToken(r)
|
|
if tok == "" {
|
|
return nil
|
|
}
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
if s.revoked[tok] {
|
|
return nil
|
|
}
|
|
id, ok := s.tokens[tok]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
return s.accounts[id]
|
|
}
|
|
// sendJSON writes v as a JSON response body with the given HTTP status and a
// "Content-Type: application/json" header.
//
// The value is marshalled before anything is written, so an unencodable value
// no longer produces the requested (possibly 2xx) status with an empty body.
// In that case the response is 500 with a JSON error object.
func sendJSON(w http.ResponseWriter, status int, v interface{}) {
	w.Header().Set("Content-Type", "application/json")

	body, err := json.Marshal(v)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		// Write errors mean the client went away; nothing useful can be done.
		_, _ = w.Write([]byte("{\"error\":\"internal error\"}\n"))
		return
	}

	w.WriteHeader(status)
	// Keep the trailing newline that json.Encoder.Encode used to emit.
	// Write errors mean the client went away; nothing useful can be done.
	_, _ = w.Write(append(body, '\n'))
}
|
|
func sendError(w http.ResponseWriter, status int, msg string) {
|
|
sendJSON(w, status, map[string]string{"error": msg})
|
|
}
|
|
func (s *Server) accountToMap(a *Account) map[string]interface{} {
|
|
return map[string]interface{}{
|
|
"id": a.ID,
|
|
"username": a.Username,
|
|
"account_type": a.AccountType,
|
|
"status": a.Status,
|
|
"created_at": "2023-11-15T12:00:00Z",
|
|
"updated_at": "2023-11-15T12:00:00Z",
|
|
"totp_enabled": false,
|
|
}
|
|
}
|
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
sendJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
|
}
|
|
func (s *Server) handlePublicKey(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
sendJSON(w, http.StatusOK, map[string]string{
|
|
"kty": "OKP",
|
|
"crv": "Ed25519",
|
|
"x": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
|
|
"use": "sig",
|
|
"alg": "EdDSA",
|
|
})
|
|
}
|
|
func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
var req struct {
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
sendError(w, http.StatusBadRequest, "bad request")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
acct, ok := s.byName[req.Username]
|
|
if !ok || acct.Password != req.Password || acct.Status != "active" {
|
|
sendError(w, http.StatusUnauthorized, "invalid credentials")
|
|
return
|
|
}
|
|
tok := s.issueToken(acct.ID)
|
|
sendJSON(w, http.StatusOK, map[string]string{
|
|
"token": tok,
|
|
"expires_at": "2099-01-01T00:00:00Z",
|
|
})
|
|
}
|
|
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
tok := s.bearerToken(r)
|
|
if tok == "" {
|
|
sendError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.revoked[tok] = true
|
|
delete(s.tokens, tok)
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
func (s *Server) handleRenew(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
tok := s.bearerToken(r)
|
|
if tok == "" {
|
|
sendError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if s.revoked[tok] {
|
|
sendError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
aid, ok := s.tokens[tok]
|
|
if !ok {
|
|
sendError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
s.revoked[tok] = true
|
|
delete(s.tokens, tok)
|
|
newTok := s.issueToken(aid)
|
|
sendJSON(w, http.StatusOK, map[string]string{
|
|
"token": newTok,
|
|
"expires_at": "2099-01-01T00:00:00Z",
|
|
})
|
|
}
|
|
func (s *Server) handleValidate(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
var req struct {
|
|
Token string `json:"token"`
|
|
}
|
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
|
tok := req.Token
|
|
if tok == "" {
|
|
tok = s.bearerToken(r)
|
|
}
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
if tok == "" || s.revoked[tok] {
|
|
sendJSON(w, http.StatusOK, map[string]interface{}{"valid": false})
|
|
return
|
|
}
|
|
aid, ok := s.tokens[tok]
|
|
if !ok {
|
|
sendJSON(w, http.StatusOK, map[string]interface{}{"valid": false})
|
|
return
|
|
}
|
|
acct := s.accounts[aid]
|
|
sendJSON(w, http.StatusOK, map[string]interface{}{
|
|
"valid": true,
|
|
"sub": acct.ID,
|
|
"roles": acct.Roles,
|
|
"expires_at": "2099-01-01T00:00:00Z",
|
|
})
|
|
}
|
|
func (s *Server) handleIssueToken(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
return
|
|
}
|
|
acct := s.authenticatedAccount(r)
|
|
if acct == nil {
|
|
sendError(w, http.StatusUnauthorized, "unauthorized")
|
|
return
|
|
}
|
|
isAdmin := false
|
|
for _, role := range acct.Roles {
|
|
if role == "admin" {
|
|
isAdmin = true
|
|
break
|
|
}
|
|
}
|
|
if !isAdmin {
|
|
sendError(w, http.StatusForbidden, "forbidden")
|
|
return
|
|
}
|
|
var req struct {
|
|
AccountID string `json:"account_id"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.AccountID == "" {
|
|
sendError(w, http.StatusBadRequest, "bad request")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if _, ok := s.accounts[req.AccountID]; !ok {
|
|
sendError(w, http.StatusNotFound, "account not found")
|
|
return
|
|
}
|
|
tok := s.issueToken(req.AccountID)
|
|
sendJSON(w, http.StatusOK, map[string]string{
|
|
"token": tok,
|
|
"expires_at": "2099-01-01T00:00:00Z",
|
|
})
|
|
}
|
|
func (s *Server) handleAccounts(w http.ResponseWriter, r *http.Request) {
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
if s.requireAdmin(w, r) == nil {
|
|
return
|
|
}
|
|
s.mu.RLock()
|
|
list := make([]map[string]interface{}, 0, len(s.accounts))
|
|
for _, a := range s.accounts {
|
|
list = append(list, s.accountToMap(a))
|
|
}
|
|
s.mu.RUnlock()
|
|
sendJSON(w, http.StatusOK, list)
|
|
case http.MethodPost:
|
|
if s.requireAdmin(w, r) == nil {
|
|
return
|
|
}
|
|
var req struct {
|
|
Username string `json:"username"`
|
|
AccountType string `json:"account_type"`
|
|
Password string `json:"password"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.Username == "" {
|
|
sendError(w, http.StatusBadRequest, "bad request")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
if _, exists := s.byName[req.Username]; exists {
|
|
s.mu.Unlock()
|
|
sendError(w, http.StatusConflict, "username already exists")
|
|
return
|
|
}
|
|
s.nextSeq++
|
|
id := fmt.Sprintf("mock-uuid-%d", s.nextSeq)
|
|
newAcct := &Account{
|
|
ID: id,
|
|
Username: req.Username,
|
|
Password: req.Password,
|
|
AccountType: req.AccountType,
|
|
Status: "active",
|
|
Roles: []string{},
|
|
}
|
|
s.accounts[id] = newAcct
|
|
s.byName[req.Username] = newAcct
|
|
s.mu.Unlock()
|
|
sendJSON(w, http.StatusCreated, s.accountToMap(newAcct))
|
|
default:
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
}
|
|
}
|
|
func (s *Server) handleAccountByID(w http.ResponseWriter, r *http.Request) {
|
|
// Parse path: /v1/accounts/{id}[/roles|/pgcreds]
|
|
path := strings.TrimPrefix(r.URL.Path, "/v1/accounts/")
|
|
parts := strings.SplitN(path, "/", 2)
|
|
id := parts[0]
|
|
sub := ""
|
|
if len(parts) == 2 {
|
|
sub = parts[1]
|
|
}
|
|
switch sub {
|
|
case "roles":
|
|
s.handleRoles(w, r, id)
|
|
case "pgcreds":
|
|
s.handlePGCreds(w, r, id)
|
|
case "":
|
|
s.handleSingleAccount(w, r, id)
|
|
default:
|
|
sendError(w, http.StatusNotFound, "not found")
|
|
}
|
|
}
|
|
func (s *Server) requireAdmin(w http.ResponseWriter, r *http.Request) *Account {
|
|
acct := s.authenticatedAccount(r)
|
|
if acct == nil {
|
|
sendError(w, http.StatusUnauthorized, "unauthorized")
|
|
return nil
|
|
}
|
|
for _, role := range acct.Roles {
|
|
if role == "admin" {
|
|
return acct
|
|
}
|
|
}
|
|
sendError(w, http.StatusForbidden, "forbidden")
|
|
return nil
|
|
}
|
|
func (s *Server) handleSingleAccount(w http.ResponseWriter, r *http.Request, id string) {
|
|
if s.requireAdmin(w, r) == nil {
|
|
return
|
|
}
|
|
s.mu.RLock()
|
|
acct, ok := s.accounts[id]
|
|
s.mu.RUnlock()
|
|
if !ok {
|
|
sendError(w, http.StatusNotFound, "account not found")
|
|
return
|
|
}
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
sendJSON(w, http.StatusOK, s.accountToMap(acct))
|
|
case http.MethodPatch:
|
|
var req struct {
|
|
Status string `json:"status"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
sendError(w, http.StatusBadRequest, "bad request")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
if req.Status != "" {
|
|
acct.Status = req.Status
|
|
}
|
|
s.mu.Unlock()
|
|
sendJSON(w, http.StatusOK, s.accountToMap(acct))
|
|
case http.MethodDelete:
|
|
s.mu.Lock()
|
|
delete(s.accounts, id)
|
|
delete(s.byName, acct.Username)
|
|
s.mu.Unlock()
|
|
w.WriteHeader(http.StatusNoContent)
|
|
default:
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
}
|
|
}
|
|
func (s *Server) handleRoles(w http.ResponseWriter, r *http.Request, id string) {
|
|
if s.requireAdmin(w, r) == nil {
|
|
return
|
|
}
|
|
s.mu.RLock()
|
|
acct, ok := s.accounts[id]
|
|
s.mu.RUnlock()
|
|
if !ok {
|
|
sendError(w, http.StatusNotFound, "account not found")
|
|
return
|
|
}
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
s.mu.RLock()
|
|
roles := append([]string{}, acct.Roles...)
|
|
s.mu.RUnlock()
|
|
sendJSON(w, http.StatusOK, map[string]interface{}{"roles": roles})
|
|
case http.MethodPut:
|
|
var req struct {
|
|
Roles []string `json:"roles"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
sendError(w, http.StatusBadRequest, "bad request")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
acct.Roles = req.Roles
|
|
s.mu.Unlock()
|
|
sendJSON(w, http.StatusOK, map[string]interface{}{"roles": req.Roles})
|
|
default:
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
}
|
|
}
|
|
func (s *Server) handlePGCreds(w http.ResponseWriter, r *http.Request, id string) {
|
|
if s.requireAdmin(w, r) == nil {
|
|
return
|
|
}
|
|
s.mu.RLock()
|
|
_, ok := s.accounts[id]
|
|
s.mu.RUnlock()
|
|
if !ok {
|
|
sendError(w, http.StatusNotFound, "account not found")
|
|
return
|
|
}
|
|
switch r.Method {
|
|
case http.MethodGet:
|
|
s.mu.RLock()
|
|
creds, hasCreds := s.pgcreds[id]
|
|
s.mu.RUnlock()
|
|
if !hasCreds {
|
|
sendError(w, http.StatusNotFound, "no pg credentials")
|
|
return
|
|
}
|
|
sendJSON(w, http.StatusOK, map[string]interface{}{
|
|
"host": creds.Host,
|
|
"port": creds.Port,
|
|
"database": creds.Database,
|
|
"username": creds.Username,
|
|
"password": creds.Password,
|
|
})
|
|
case http.MethodPut:
|
|
var req struct {
|
|
Host string `json:"host"`
|
|
Database string `json:"database"`
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
Port int `json:"port"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
sendError(w, http.StatusBadRequest, "bad request")
|
|
return
|
|
}
|
|
s.mu.Lock()
|
|
s.pgcreds[id] = &PGCreds{
|
|
Host: req.Host,
|
|
Port: req.Port,
|
|
Database: req.Database,
|
|
Username: req.Username,
|
|
Password: req.Password,
|
|
}
|
|
s.mu.Unlock()
|
|
w.WriteHeader(http.StatusNoContent)
|
|
default:
|
|
sendError(w, http.StatusMethodNotAllowed, "method not allowed")
|
|
}
|
|
}
|