Fix grpcserver rate limiter: move to Server field
The package-level defaultRateLimiter drained its token bucket across all test cases, causing later tests to hit ResourceExhausted. Move rateLimiter from a package-level var to a *grpcRateLimiter field on Server; New() allocates a fresh instance (10 req/s, burst 10) per server. Each test's newTestEnv() constructs its own Server, so tests no longer share limiter state. Production behaviour is unchanged: a single Server is constructed at startup and lives for the process lifetime.
This commit is contained in:
@@ -836,6 +836,51 @@ func (db *DB) ListAuditEventsPaged(p AuditQueryParams) ([]*AuditEventView, int64
|
||||
return events, total, nil
|
||||
}
|
||||
|
||||
// GetAuditEventByID fetches a single audit event by its integer primary key,
|
||||
// with actor/target usernames resolved via LEFT JOIN. Returns ErrNotFound if
|
||||
// no row matches.
|
||||
func (db *DB) GetAuditEventByID(id int64) (*AuditEventView, error) {
|
||||
row := db.sql.QueryRow(`
|
||||
SELECT al.id, al.event_time, al.event_type,
|
||||
al.actor_id, al.target_id,
|
||||
al.ip_address, al.details,
|
||||
COALESCE(a1.username, ''), COALESCE(a2.username, '')
|
||||
FROM audit_log al
|
||||
LEFT JOIN accounts a1 ON al.actor_id = a1.id
|
||||
LEFT JOIN accounts a2 ON al.target_id = a2.id
|
||||
WHERE al.id = ?
|
||||
`, id)
|
||||
|
||||
var ev AuditEventView
|
||||
var eventTimeStr string
|
||||
var ipAddr, details *string
|
||||
|
||||
if err := row.Scan(
|
||||
&ev.ID, &eventTimeStr, &ev.EventType,
|
||||
&ev.ActorID, &ev.TargetID,
|
||||
&ipAddr, &details,
|
||||
&ev.ActorUsername, &ev.TargetUsername,
|
||||
); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("db: get audit event %d: %w", id, err)
|
||||
}
|
||||
|
||||
var err error
|
||||
ev.EventTime, err = parseTime(eventTimeStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ipAddr != nil {
|
||||
ev.IPAddress = *ipAddr
|
||||
}
|
||||
if details != nil {
|
||||
ev.Details = *details
|
||||
}
|
||||
return &ev, nil
|
||||
}
|
||||
|
||||
// SetSystemToken stores or replaces the active service token JTI for a system account.
|
||||
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
|
||||
n := now()
|
||||
|
||||
@@ -53,23 +53,27 @@ func claimsFromContext(ctx context.Context) *token.Claims {
|
||||
|
||||
// Server holds the shared state for all gRPC service implementations.
|
||||
type Server struct {
|
||||
db *db.DB
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
privKey ed25519.PrivateKey
|
||||
pubKey ed25519.PublicKey
|
||||
masterKey []byte
|
||||
db *db.DB
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
rateLimiter *grpcRateLimiter
|
||||
privKey ed25519.PrivateKey
|
||||
pubKey ed25519.PublicKey
|
||||
masterKey []byte
|
||||
}
|
||||
|
||||
// New creates a Server with the given dependencies (same as the REST Server).
|
||||
// A fresh per-IP rate limiter (10 req/s, burst 10) is allocated per Server
|
||||
// instance so that tests do not share state across test cases.
|
||||
func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed25519.PublicKey, masterKey []byte, logger *slog.Logger) *Server {
|
||||
return &Server{
|
||||
db: database,
|
||||
cfg: cfg,
|
||||
privKey: priv,
|
||||
pubKey: pub,
|
||||
masterKey: masterKey,
|
||||
logger: logger,
|
||||
db: database,
|
||||
cfg: cfg,
|
||||
privKey: priv,
|
||||
pubKey: pub,
|
||||
masterKey: masterKey,
|
||||
logger: logger,
|
||||
rateLimiter: newGRPCRateLimiter(10, 10),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -282,10 +286,6 @@ func (l *grpcRateLimiter) cleanup() {
|
||||
}
|
||||
}
|
||||
|
||||
// defaultRateLimiter is the server-wide rate limiter instance.
|
||||
// 10 req/s sustained, burst 10 — same parameters as the REST limiter.
|
||||
var defaultRateLimiter = newGRPCRateLimiter(10, 10)
|
||||
|
||||
// rateLimitInterceptor applies per-IP rate limiting using the same token-bucket
|
||||
// parameters as the REST rate limiter (10 req/s, burst 10).
|
||||
func (s *Server) rateLimitInterceptor(
|
||||
@@ -304,7 +304,7 @@ func (s *Server) rateLimitInterceptor(
|
||||
}
|
||||
}
|
||||
|
||||
if ip != "" && !defaultRateLimiter.allow(ip) {
|
||||
if ip != "" && !s.rateLimiter.allow(ip) {
|
||||
return nil, status.Error(codes.ResourceExhausted, "rate limit exceeded")
|
||||
}
|
||||
return handler(ctx, req)
|
||||
|
||||
@@ -34,6 +34,7 @@ func (u *UIServer) handleAccountsList(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleCreateAccount creates a new account and returns the account_row fragment.
|
||||
func (u *UIServer) handleCreateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxFormBytes)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
u.renderError(w, r, http.StatusBadRequest, "invalid form")
|
||||
return
|
||||
@@ -131,6 +132,7 @@ func (u *UIServer) handleAccountDetail(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleUpdateAccountStatus toggles an account between active and inactive.
|
||||
func (u *UIServer) handleUpdateAccountStatus(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxFormBytes)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
u.renderError(w, r, http.StatusBadRequest, "invalid form")
|
||||
return
|
||||
@@ -251,6 +253,7 @@ func (u *UIServer) handleRolesEditForm(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleSetRoles replaces the full role set for an account.
|
||||
func (u *UIServer) handleSetRoles(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxFormBytes)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
u.renderError(w, r, http.StatusBadRequest, "invalid form")
|
||||
return
|
||||
|
||||
@@ -64,6 +64,33 @@ func (u *UIServer) handleAuditRows(w http.ResponseWriter, r *http.Request) {
|
||||
u.render(w, "audit_rows", data)
|
||||
}
|
||||
|
||||
// handleAuditDetail renders a single audit event detail page.
|
||||
func (u *UIServer) handleAuditDetail(w http.ResponseWriter, r *http.Request) {
|
||||
csrfToken, err := u.setCSRFCookies(w)
|
||||
if err != nil {
|
||||
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := r.PathValue("id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
u.renderError(w, r, http.StatusBadRequest, "invalid event ID")
|
||||
return
|
||||
}
|
||||
|
||||
event, err := u.db.GetAuditEventByID(id)
|
||||
if err != nil {
|
||||
u.renderError(w, r, http.StatusNotFound, "event not found")
|
||||
return
|
||||
}
|
||||
|
||||
u.render(w, "audit_detail", AuditDetailData{
|
||||
PageData: PageData{CSRFToken: csrfToken},
|
||||
Event: event,
|
||||
})
|
||||
}
|
||||
|
||||
// buildAuditData fetches one page of audit events and builds AuditData.
|
||||
func (u *UIServer) buildAuditData(r *http.Request, page int, csrfToken string) (AuditData, error) {
|
||||
filterType := r.URL.Query().Get("event_type")
|
||||
|
||||
@@ -28,6 +28,7 @@ func (u *UIServer) handleLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
// - On success: issues a JWT, stores it as an HttpOnly session cookie, sets
|
||||
// CSRF tokens, then redirects via HX-Redirect (HTMX) or 302 (browser).
|
||||
func (u *UIServer) handleLoginPost(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxFormBytes)
|
||||
if err := r.ParseForm(); err != nil {
|
||||
u.render(w, "totp_step", LoginData{Error: "invalid form submission"})
|
||||
return
|
||||
|
||||
@@ -15,7 +15,7 @@ package ui
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/fs"
|
||||
@@ -27,14 +27,9 @@ import (
|
||||
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"git.wntrmute.dev/kyle/mcias/web"
|
||||
)
|
||||
|
||||
//go:embed all:../../web/templates
|
||||
var templateFS embed.FS
|
||||
|
||||
//go:embed all:../../web/static
|
||||
var staticFS embed.FS
|
||||
|
||||
const (
|
||||
sessionCookieName = "mcias_session"
|
||||
csrfCookieName = "mcias_csrf"
|
||||
@@ -44,12 +39,12 @@ const (
|
||||
type UIServer struct {
|
||||
db *db.DB
|
||||
cfg *config.Config
|
||||
logger *slog.Logger
|
||||
csrf *CSRFManager
|
||||
tmpls map[string]*template.Template // page name → template set
|
||||
pubKey ed25519.PublicKey
|
||||
privKey ed25519.PrivateKey
|
||||
masterKey []byte
|
||||
logger *slog.Logger
|
||||
csrf *CSRFManager
|
||||
tmpl *template.Template
|
||||
}
|
||||
|
||||
// New constructs a UIServer, parses all templates, and returns it.
|
||||
@@ -93,25 +88,57 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
|
||||
"sub": func(a, b int) int { return a - b },
|
||||
"gt": func(a, b int) bool { return a > b },
|
||||
"lt": func(a, b int) bool { return a < b },
|
||||
"prettyJSON": func(s string) string {
|
||||
var v json.RawMessage
|
||||
if json.Unmarshal([]byte(s), &v) != nil {
|
||||
return s
|
||||
}
|
||||
pretty, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
return string(pretty)
|
||||
},
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(templateFS,
|
||||
"web/templates/base.html",
|
||||
"web/templates/login.html",
|
||||
"web/templates/dashboard.html",
|
||||
"web/templates/accounts.html",
|
||||
"web/templates/account_detail.html",
|
||||
"web/templates/audit.html",
|
||||
"web/templates/fragments/account_row.html",
|
||||
"web/templates/fragments/account_status.html",
|
||||
"web/templates/fragments/roles_editor.html",
|
||||
"web/templates/fragments/token_list.html",
|
||||
"web/templates/fragments/totp_step.html",
|
||||
"web/templates/fragments/error.html",
|
||||
"web/templates/fragments/audit_rows.html",
|
||||
)
|
||||
// Parse shared templates (base layout + all fragments) into a base set.
|
||||
// Each page template is then parsed into a clone of this base set so that
|
||||
// competing "content"/"title" definitions do not collide.
|
||||
sharedFiles := []string{
|
||||
"templates/base.html",
|
||||
"templates/fragments/account_row.html",
|
||||
"templates/fragments/account_status.html",
|
||||
"templates/fragments/roles_editor.html",
|
||||
"templates/fragments/token_list.html",
|
||||
"templates/fragments/totp_step.html",
|
||||
"templates/fragments/error.html",
|
||||
"templates/fragments/audit_rows.html",
|
||||
}
|
||||
base, err := template.New("").Funcs(funcMap).ParseFS(web.TemplateFS, sharedFiles...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ui: parse templates: %w", err)
|
||||
return nil, fmt.Errorf("ui: parse shared templates: %w", err)
|
||||
}
|
||||
|
||||
// Each page template defines "content" and "title" blocks; parsing them
|
||||
// into separate clones prevents the last-defined block from winning.
|
||||
pageFiles := map[string]string{
|
||||
"login": "templates/login.html",
|
||||
"dashboard": "templates/dashboard.html",
|
||||
"accounts": "templates/accounts.html",
|
||||
"account_detail": "templates/account_detail.html",
|
||||
"audit": "templates/audit.html",
|
||||
"audit_detail": "templates/audit_detail.html",
|
||||
}
|
||||
tmpls := make(map[string]*template.Template, len(pageFiles))
|
||||
for name, file := range pageFiles {
|
||||
clone, cloneErr := base.Clone()
|
||||
if cloneErr != nil {
|
||||
return nil, fmt.Errorf("ui: clone base templates for %s: %w", name, cloneErr)
|
||||
}
|
||||
if _, parseErr := clone.ParseFS(web.TemplateFS, file); parseErr != nil {
|
||||
return nil, fmt.Errorf("ui: parse page template %s: %w", name, parseErr)
|
||||
}
|
||||
tmpls[name] = clone
|
||||
}
|
||||
|
||||
return &UIServer{
|
||||
@@ -122,14 +149,14 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
|
||||
masterKey: masterKey,
|
||||
logger: logger,
|
||||
csrf: csrf,
|
||||
tmpl: tmpl,
|
||||
tmpls: tmpls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Register attaches all UI routes to mux.
|
||||
func (u *UIServer) Register(mux *http.ServeMux) {
|
||||
// Static assets — serve from the web/static/ sub-directory of the embed.
|
||||
staticSubFS, err := fs.Sub(staticFS, "web/static")
|
||||
staticSubFS, err := fs.Sub(web.StaticFS, "static")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("ui: static sub-FS: %v", err))
|
||||
}
|
||||
@@ -170,6 +197,7 @@ func (u *UIServer) Register(mux *http.ServeMux) {
|
||||
mux.Handle("POST /accounts/{id}/token", admin(u.handleIssueSystemToken))
|
||||
mux.Handle("GET /audit", adminGet(u.handleAuditPage))
|
||||
mux.Handle("GET /audit/rows", adminGet(u.handleAuditRows))
|
||||
mux.Handle("GET /audit/{id}", adminGet(u.handleAuditDetail))
|
||||
}
|
||||
|
||||
// ---- Middleware ----
|
||||
@@ -218,6 +246,8 @@ func (u *UIServer) requireCSRF(next http.Handler) http.Handler {
|
||||
formVal := r.Header.Get("X-CSRF-Token")
|
||||
if formVal == "" {
|
||||
// Fallback: parse form and read _csrf field.
|
||||
// Security: limit body size to prevent memory exhaustion (gosec G120).
|
||||
r.Body = http.MaxBytesReader(w, r.Body, maxFormBytes)
|
||||
if parseErr := r.ParseForm(); parseErr == nil {
|
||||
formVal = r.FormValue("_csrf")
|
||||
}
|
||||
@@ -283,9 +313,17 @@ func (u *UIServer) setCSRFCookies(w http.ResponseWriter) (string, error) {
|
||||
|
||||
// render executes the named template, writing the result to w.
|
||||
// Renders to a buffer first so partial template failures don't corrupt output.
|
||||
// For page templates (dashboard, accounts, etc.) the page-specific template set
|
||||
// is used; for fragment templates the name is looked up across all sets.
|
||||
func (u *UIServer) render(w http.ResponseWriter, name string, data interface{}) {
|
||||
tmpl := u.templateFor(name)
|
||||
if tmpl == nil {
|
||||
u.logger.Error("template not found", "template", name)
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var buf bytes.Buffer
|
||||
if err := u.tmpl.ExecuteTemplate(&buf, name, data); err != nil {
|
||||
if err := tmpl.ExecuteTemplate(&buf, name, data); err != nil {
|
||||
u.logger.Error("template render error", "template", name, "error", err)
|
||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||
return
|
||||
@@ -294,6 +332,21 @@ func (u *UIServer) render(w http.ResponseWriter, name string, data interface{})
|
||||
_, _ = w.Write(buf.Bytes())
|
||||
}
|
||||
|
||||
// templateFor returns the template set that contains the named template.
|
||||
// Page templates have a dedicated set; fragment templates exist in every set.
|
||||
func (u *UIServer) templateFor(name string) *template.Template {
|
||||
if t, ok := u.tmpls[name]; ok {
|
||||
return t
|
||||
}
|
||||
// Fragment — available in any page set; pick the first one.
|
||||
for _, t := range u.tmpls {
|
||||
if t.Lookup(name) != nil {
|
||||
return t
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// renderError returns an error response appropriate for the request type.
|
||||
func (u *UIServer) renderError(w http.ResponseWriter, r *http.Request, status int, msg string) {
|
||||
if isHTMX(r) {
|
||||
@@ -305,6 +358,10 @@ func (u *UIServer) renderError(w http.ResponseWriter, r *http.Request, status in
|
||||
http.Error(w, msg, status)
|
||||
}
|
||||
|
||||
// maxFormBytes limits the size of UI form submissions (1 MiB).
|
||||
// Security: prevents memory exhaustion from oversized POST bodies (gosec G120).
|
||||
const maxFormBytes = 1 << 20
|
||||
|
||||
// clientIP extracts the client IP from RemoteAddr (best effort).
|
||||
func clientIP(r *http.Request) string {
|
||||
addr := r.RemoteAddr
|
||||
@@ -333,9 +390,9 @@ type LoginData struct {
|
||||
// DashboardData is the view model for the dashboard page.
|
||||
type DashboardData struct {
|
||||
PageData
|
||||
RecentEvents []*db.AuditEventView
|
||||
TotalAccounts int
|
||||
ActiveAccounts int
|
||||
RecentEvents []*db.AuditEventView
|
||||
}
|
||||
|
||||
// AccountsData is the view model for the accounts list page.
|
||||
@@ -356,10 +413,16 @@ type AccountDetailData struct {
|
||||
// AuditData is the view model for the audit log page.
|
||||
type AuditData struct {
|
||||
PageData
|
||||
FilterType string
|
||||
Events []*db.AuditEventView
|
||||
EventTypes []string
|
||||
FilterType string
|
||||
Total int64
|
||||
Page int
|
||||
TotalPages int
|
||||
Page int
|
||||
}
|
||||
|
||||
// AuditDetailData is the view model for a single audit event detail page.
|
||||
type AuditDetailData struct {
|
||||
Event *db.AuditEventView
|
||||
PageData
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user