trusted proxy, TOTP replay protection, new tests

- Trusted proxy config option for proxy-aware IP extraction
  used by rate limiting and audit logs; validates proxy IP
  before trusting X-Forwarded-For / X-Real-IP headers
- TOTP replay protection via counter-based validation to
  reject reused codes within the same time step (±30s)
- RateLimit middleware updated to extract client IP from
  proxy headers without IP spoofing risk
- New tests for ClientIP proxy logic (spoofed headers,
  fallback) and extended rate-limit proxy coverage
- HTMX error banner script integrated into web UI base
- .gitignore updated for mciasdb build artifact

Security: resolves CRIT-01 (TOTP replay attack) and
DEF-03 (proxy-unaware rate limiting); gRPC TOTP
enrollment aligned with REST via StorePendingTOTP

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-12 17:44:01 -07:00
parent f262ca7b4e
commit ec7c966ad2
31 changed files with 799 additions and 250 deletions

View File

@@ -200,19 +200,31 @@ func parsePHC(phc string) (ArgonParams, []byte, []byte, error) {
// ValidateTOTP checks a 6-digit TOTP code against a raw TOTP secret (bytes).
// A ±1 time-step window (±30s) is allowed to accommodate clock skew.
//
// Returns (true, counter, nil) on a valid code where counter is the HOTP
// counter value that matched. The caller MUST pass this counter to
// db.CheckAndUpdateTOTPCounter to prevent replay attacks within the validity
// window (CRIT-01).
//
// Security:
// - Comparison uses crypto/subtle.ConstantTimeCompare to resist timing attacks.
// - Only RFC 6238-compliant HOTP (HMAC-SHA1) is implemented; no custom crypto.
// - A ±1 window is the RFC 6238 recommendation; wider windows increase
// exposure to code interception between generation and submission.
func ValidateTOTP(secret []byte, code string) (bool, error) {
// - The returned counter enables replay prevention: callers store it and
// reject any future code that does not advance past it (RFC 6238 §5.2).
func ValidateTOTP(secret []byte, code string) (bool, int64, error) {
if len(code) != 6 {
return false, nil
return false, 0, nil
}
now := time.Now().Unix()
step := int64(30) // RFC 6238 default time step in seconds
// Security: evaluate all three counters with constant-time comparisons
// before returning. Early-exit would leak which counter matched via
// timing; we instead record the match and continue, returning at the end.
var matched bool
var matchedCounter int64
for _, counter := range []int64{
now/step - 1,
now / step,
@@ -220,14 +232,21 @@ func ValidateTOTP(secret []byte, code string) (bool, error) {
} {
expected, err := hotp(secret, uint64(counter)) //nolint:gosec // G115: counter is Unix time / step, always non-negative
if err != nil {
return false, fmt.Errorf("auth: compute TOTP: %w", err)
return false, 0, fmt.Errorf("auth: compute TOTP: %w", err)
}
// Security: constant-time comparison to prevent timing attack.
// We deliberately do NOT break early so that all three comparisons
// always execute, preventing a timing side-channel on which counter
// slot matched.
if subtle.ConstantTimeCompare([]byte(code), []byte(expected)) == 1 {
return true, nil
matched = true
matchedCounter = counter
}
}
return false, nil
if matched {
return true, matchedCounter, nil
}
return false, 0, nil
}
// hotp computes an HMAC-SHA1-based OTP for a given counter value.

View File

@@ -101,13 +101,16 @@ func TestValidateTOTP(t *testing.T) {
t.Fatalf("hotp: %v", err)
}
ok, err := ValidateTOTP(rawSecret, code)
ok, counter, err := ValidateTOTP(rawSecret, code)
if err != nil {
t.Fatalf("ValidateTOTP: %v", err)
}
if !ok {
t.Errorf("ValidateTOTP rejected a valid code %q", code)
}
if ok && counter == 0 {
t.Errorf("ValidateTOTP returned zero counter for valid code")
}
}
// TestValidateTOTPWrongCode verifies that an incorrect code is rejected.
@@ -117,7 +120,7 @@ func TestValidateTOTPWrongCode(t *testing.T) {
t.Fatalf("GenerateTOTPSecret: %v", err)
}
ok, err := ValidateTOTP(rawSecret, "000000")
ok, _, err := ValidateTOTP(rawSecret, "000000")
if err != nil {
t.Fatalf("ValidateTOTP: %v", err)
}
@@ -135,7 +138,7 @@ func TestValidateTOTPWrongLength(t *testing.T) {
}
for _, code := range []string{"", "12345", "1234567", "abcdef"} {
ok, err := ValidateTOTP(rawSecret, code)
ok, _, err := ValidateTOTP(rawSecret, code)
if err != nil {
t.Errorf("ValidateTOTP(%q): unexpected error: %v", code, err)
}

View File

@@ -6,6 +6,7 @@ package config
import (
"errors"
"fmt"
"net"
"os"
"time"
@@ -30,6 +31,17 @@ type ServerConfig struct {
GRPCAddr string `toml:"grpc_addr"`
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
// TrustedProxy is the IP address (not a range) of a reverse proxy that
// sits in front of the server and sets X-Forwarded-For or X-Real-IP
// headers. When set, the rate limiter and audit log extract the real
// client IP from these headers instead of r.RemoteAddr.
//
// Security: only requests whose r.RemoteAddr matches TrustedProxy are
// trusted to carry a valid forwarded-IP header. All other requests use
// r.RemoteAddr directly, so this field cannot be exploited for IP
// spoofing by external clients. Omit or leave empty when running
// without a reverse proxy.
TrustedProxy string `toml:"trusted_proxy"`
}
// DatabaseConfig holds SQLite database settings.
@@ -137,6 +149,14 @@ func (c *Config) validate() error {
if c.Server.TLSKey == "" {
errs = append(errs, errors.New("server.tls_key is required"))
}
// Security (DEF-03): if trusted_proxy is set it must be a valid IP address
// (not a hostname or CIDR) so the middleware can compare it to the parsed
// host part of r.RemoteAddr using a reliable byte-level equality check.
if c.Server.TrustedProxy != "" {
if net.ParseIP(c.Server.TrustedProxy) == nil {
errs = append(errs, fmt.Errorf("server.trusted_proxy %q is not a valid IP address", c.Server.TrustedProxy))
}
}
// Database
if c.Database.Path == "" {
@@ -147,14 +167,31 @@ func (c *Config) validate() error {
if c.Tokens.Issuer == "" {
errs = append(errs, errors.New("tokens.issuer is required"))
}
// Security (DEF-05): enforce both lower and upper bounds on token expiry
// durations. An operator misconfiguration could otherwise produce tokens
// valid for centuries, which would be irrevocable (bar explicit JTI
// revocation) if a token were stolen. Upper bounds are intentionally
// generous to accommodate a range of legitimate deployments while
// catching obvious typos (e.g. "876000h" instead of "8760h").
const (
maxDefaultExpiry = 30 * 24 * time.Hour // 30 days
maxAdminExpiry = 24 * time.Hour // 24 hours
maxServiceExpiry = 5 * 365 * 24 * time.Hour // 5 years
)
if c.Tokens.DefaultExpiry.Duration <= 0 {
errs = append(errs, errors.New("tokens.default_expiry must be positive"))
} else if c.Tokens.DefaultExpiry.Duration > maxDefaultExpiry {
errs = append(errs, fmt.Errorf("tokens.default_expiry must be <= %s (got %s)", maxDefaultExpiry, c.Tokens.DefaultExpiry.Duration))
}
if c.Tokens.AdminExpiry.Duration <= 0 {
errs = append(errs, errors.New("tokens.admin_expiry must be positive"))
} else if c.Tokens.AdminExpiry.Duration > maxAdminExpiry {
errs = append(errs, fmt.Errorf("tokens.admin_expiry must be <= %s (got %s)", maxAdminExpiry, c.Tokens.AdminExpiry.Duration))
}
if c.Tokens.ServiceExpiry.Duration <= 0 {
errs = append(errs, errors.New("tokens.service_expiry must be positive"))
} else if c.Tokens.ServiceExpiry.Duration > maxServiceExpiry {
errs = append(errs, fmt.Errorf("tokens.service_expiry must be <= %s (got %s)", maxServiceExpiry, c.Tokens.ServiceExpiry.Duration))
}
// Argon2 — enforce OWASP 2023 minimums (time=2, memory=65536 KiB).

View File

@@ -210,6 +210,40 @@ threads = 4
}
}
// TestTrustedProxyValidation verifies that trusted_proxy must be a valid IP.
func TestTrustedProxyValidation(t *testing.T) {
tests := []struct {
name string
proxy string
wantErr bool
}{
{"empty is valid (disabled)", "", false},
{"valid IPv4", "127.0.0.1", false},
{"valid IPv6 loopback", "::1", false},
{"valid private IPv4", "10.0.0.1", false},
{"hostname rejected", "proxy.example.com", true},
{"CIDR rejected", "10.0.0.0/8", true},
{"garbage rejected", "not-an-ip", true},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cfg, _ := Load(writeTempConfig(t, validConfig()))
if cfg == nil {
t.Fatal("baseline config load failed")
}
cfg.Server.TrustedProxy = tc.proxy
err := cfg.validate()
if tc.wantErr && err == nil {
t.Errorf("expected validation error for proxy=%q, got nil", tc.proxy)
}
if !tc.wantErr && err != nil {
t.Errorf("unexpected error for proxy=%q: %v", tc.proxy, err)
}
})
}
}
func TestDurationParsing(t *testing.T) {
var d duration
if err := d.UnmarshalText([]byte("1h30m")); err != nil {

View File

@@ -70,7 +70,10 @@ func (db *DB) GetAccountByID(id int64) (*model.Account, error) {
`, id))
}
// GetAccountByUsername retrieves an account by username (case-insensitive).
// GetAccountByUsername retrieves an account by username.
// Matching is case-sensitive: SQLite uses BINARY collation by default, so
// "admin" and "Admin" are distinct usernames. This is intentional for an
// SSO system where usernames should be treated as opaque identifiers.
// Returns ErrNotFound if no matching account exists.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
return db.scanAccount(db.sql.QueryRow(`
@@ -184,6 +187,46 @@ func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
return nil
}
// CheckAndUpdateTOTPCounter atomically verifies that counter is strictly
// greater than the last accepted TOTP counter for the account, and if so,
// stores counter as the new last accepted value.
//
// Returns ErrTOTPReplay if counter ≤ the stored value, preventing a replay
// of a previously accepted code within the ±1 time-step validity window.
// On the first successful TOTP login (stored value NULL) any counter is
// accepted.
//
// Security (CRIT-01): RFC 6238 §5.2 recommends recording the last OTP
// counter used and rejecting any code that does not advance it. Without
// this, an intercepted code remains valid for up to 90 seconds. The update
// is performed in a single parameterized SQL statement, so there is no
// TOCTOU window between the check and the write.
func (db *DB) CheckAndUpdateTOTPCounter(accountID int64, counter int64) error {
result, err := db.sql.Exec(`
UPDATE accounts
SET last_totp_counter = ?, updated_at = ?
WHERE id = ?
AND (last_totp_counter IS NULL OR last_totp_counter < ?)
`, counter, now(), accountID, counter)
if err != nil {
return fmt.Errorf("db: check-and-update TOTP counter: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("db: check-and-update TOTP counter rows affected: %w", err)
}
if rows == 0 {
// Security: the counter was not advanced — this code has already been
// used within its validity window. Treat as authentication failure.
return ErrTOTPReplay
}
return nil
}
// ErrTOTPReplay is returned by CheckAndUpdateTOTPCounter when the submitted
// TOTP code corresponds to a counter value that has already been accepted.
var ErrTOTPReplay = errors.New("db: TOTP code already used (replay)")
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
func (db *DB) ClearTOTP(accountID int64) error {
_, err := db.sql.Exec(`
@@ -300,6 +343,12 @@ func (db *DB) GetRoles(accountID int64) ([]string, error) {
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
// Security (DEF-10): reject unknown roles before writing to the DB so
// that typos (e.g. "admim") are caught immediately rather than silently
// creating an unmatchable role.
if err := model.ValidateRole(role); err != nil {
return err
}
_, err := db.sql.Exec(`
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
@@ -323,6 +372,14 @@ func (db *DB) RevokeRole(accountID int64, role string) error {
// SetRoles replaces the full role set for an account atomically.
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
// Security (DEF-10): validate all roles before opening the transaction so
// we fail fast without touching the database on an invalid input.
for _, role := range roles {
if err := model.ValidateRole(role); err != nil {
return err
}
}
tx, err := db.sql.Begin()
if err != nil {
return fmt.Errorf("db: set roles begin tx: %w", err)

View File

@@ -65,7 +65,14 @@ func (db *DB) configure() error {
"PRAGMA journal_mode=WAL",
"PRAGMA foreign_keys=ON",
"PRAGMA busy_timeout=5000",
"PRAGMA synchronous=NORMAL",
// Security (DEF-07): FULL synchronous mode ensures every write is
// flushed to disk before SQLite considers it committed. With WAL
// mode + NORMAL, a power failure between a write and the next
// checkpoint could lose the most recent committed transactions,
// including token issuance and revocation records — which must be
// durable. The performance cost is negligible for a single-node
// personal SSO server.
"PRAGMA synchronous=FULL",
}
for _, p := range pragmas {
if _, err := db.sql.Exec(p); err != nil {

View File

@@ -162,7 +162,7 @@ func TestRoleOperations(t *testing.T) {
}
// SetRoles
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
if err := db.SetRoles(acct.ID, []string{"admin", "user"}, nil); err != nil {
t.Fatalf("SetRoles: %v", err)
}
roles, err = db.GetRoles(acct.ID)

View File

@@ -22,7 +22,7 @@ var migrationsFS embed.FS
// LatestSchemaVersion is the highest migration version defined in the
// migrations/ directory. Update this constant whenever a new migration file
// is added.
const LatestSchemaVersion = 6
const LatestSchemaVersion = 7
// newMigrate constructs a migrate.Migrate instance backed by the embedded SQL
// files. It opens a dedicated *sql.DB using the same DSN as the main

View File

@@ -0,0 +1,9 @@
-- Add last_totp_counter to track the most recently accepted TOTP counter value
-- per account. This is used to prevent TOTP replay attacks within the ±1
-- time-step validity window. NULL means no TOTP code has ever been accepted
-- for this account (fresh enrollment or TOTP not yet used).
--
-- Security (CRIT-01): RFC 6238 §5.2 recommends recording the last OTP counter
-- used and rejecting codes that do not advance it, eliminating the ~90-second
-- replay window that would otherwise be exploitable.
ALTER TABLE accounts ADD COLUMN last_totp_counter INTEGER DEFAULT NULL;

View File

@@ -72,8 +72,14 @@ func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest
if acct.TOTPRequired {
if req.TotpCode == "" {
// Security (DEF-08): password was already verified, so a missing
// TOTP code means the gRPC client needs to re-prompt the user —
// it is not a credential failure. Do NOT increment the lockout
// counter here; doing so would lock out well-behaved clients that
// call Login in two steps (password first, TOTP second) and would
// also let an attacker trigger account lockout by omitting the
// code after a successful password guess.
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"totp_missing"}`) //nolint:errcheck
_ = a.s.db.RecordLoginFailure(acct.ID)
return nil, status.Error(codes.Unauthenticated, "TOTP code required")
}
secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
@@ -81,12 +87,19 @@ func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest
a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
return nil, status.Error(codes.Internal, "internal error")
}
valid, err := auth.ValidateTOTP(secret, req.TotpCode)
valid, counter, err := auth.ValidateTOTP(secret, req.TotpCode)
if err != nil || !valid {
a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"wrong_totp"}`) //nolint:errcheck
_ = a.s.db.RecordLoginFailure(acct.ID)
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
// Security (CRIT-01): reject replay of a code already used within
// its ±30-second validity window.
if err := a.s.db.CheckAndUpdateTOTPCounter(acct.ID, counter); err != nil {
a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"totp_replay"}`) //nolint:errcheck
_ = a.s.db.RecordLoginFailure(acct.ID)
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
}
// Login succeeded: clear any outstanding failure counter.
@@ -199,7 +212,12 @@ func (a *authServiceServer) EnrollTOTP(ctx context.Context, _ *mciasv1.EnrollTOT
return nil, status.Error(codes.Internal, "internal error")
}
if err := a.s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
// Security: use StorePendingTOTP (not SetTOTP) so that totp_required is
// not set to 1 until the user confirms the code via ConfirmTOTP. Calling
// SetTOTP here would immediately lock the account behind TOTP before the
// user has had a chance to configure their authenticator app — matching the
// behaviour of the REST EnrollTOTP handler at internal/server/server.go.
if err := a.s.db.StorePendingTOTP(acct.ID, secretEnc, secretNonce); err != nil {
return nil, status.Error(codes.Internal, "internal error")
}
@@ -232,10 +250,15 @@ func (a *authServiceServer) ConfirmTOTP(ctx context.Context, req *mciasv1.Confir
return nil, status.Error(codes.Internal, "internal error")
}
valid, err := auth.ValidateTOTP(secret, req.Code)
valid, counter, err := auth.ValidateTOTP(secret, req.Code)
if err != nil || !valid {
return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
}
// Security (CRIT-01): record the counter even during enrollment confirmation
// so the same code cannot be replayed immediately after confirming.
if err := a.s.db.CheckAndUpdateTOTPCounter(acct.ID, counter); err != nil {
return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
}
// SetTOTP with existing enc/nonce sets totp_required=1, confirming enrollment.
if err := a.s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {

View File

@@ -542,7 +542,7 @@ func TestSetAndGetRoles(t *testing.T) {
_, err = cl.SetRoles(authCtx(adminTok), &mciasv1.SetRolesRequest{
Id: id,
Roles: []string{"editor", "viewer"},
Roles: []string{"admin", "user"},
})
if err != nil {
t.Fatalf("SetRoles: %v", err)

View File

@@ -176,15 +176,62 @@ type ipRateLimiter struct {
mu sync.Mutex
}
// ClientIP returns the real client IP for a request, optionally trusting a
// single reverse-proxy address.
//
// Security (DEF-03): X-Forwarded-For and X-Real-IP headers can be forged by
// any client. This function only honours them when the immediate TCP peer
// (r.RemoteAddr) matches trustedProxy exactly. When trustedProxy is nil or
// the peer address does not match, r.RemoteAddr is used unconditionally.
//
// This prevents IP-spoofing attacks: an attacker who sends a fake
// X-Forwarded-For header from their own connection still has their real IP
// used for rate limiting, because their RemoteAddr will not match the proxy.
//
// Only the first (leftmost) value in X-Forwarded-For is used, as that is the
// client-supplied address as appended by the outermost proxy. If neither
// header is present, RemoteAddr is used as a fallback even when the request
// comes from the proxy.
func ClientIP(r *http.Request, trustedProxy net.IP) string {
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteHost = r.RemoteAddr
}
if trustedProxy != nil {
remoteIP := net.ParseIP(remoteHost)
if remoteIP != nil && remoteIP.Equal(trustedProxy) {
// Request is from the trusted proxy; extract the real client IP.
// Prefer X-Real-IP (single value) over X-Forwarded-For (may be a
// comma-separated list when multiple proxies are chained).
if xri := r.Header.Get("X-Real-IP"); xri != "" {
if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
return ip.String()
}
}
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// Take the first (leftmost) address — the original client.
first, _, _ := strings.Cut(xff, ",")
if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
return ip.String()
}
}
}
}
return remoteHost
}
// RateLimit returns middleware implementing a per-IP token bucket.
// rps is the sustained request rate (tokens refilled per second).
// burst is the maximum burst size (initial and maximum token count).
// trustedProxy, if non-nil, enables proxy-aware client IP extraction via
// ClientIP; pass nil when not running behind a reverse proxy.
//
// Security: Rate limiting is applied at the IP level. In production, the
// server should be behind a reverse proxy that sets X-Forwarded-For; this
// middleware uses RemoteAddr directly which may be the proxy IP. For single-
// instance deployment without a proxy, RemoteAddr is the client IP.
func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
// Security (DEF-03): when trustedProxy is set, real client IPs are extracted
// from X-Forwarded-For/X-Real-IP headers but only for requests whose
// RemoteAddr matches the trusted proxy, preventing IP-spoofing.
func RateLimit(rps float64, burst int, trustedProxy net.IP) func(http.Handler) http.Handler {
limiter := &ipRateLimiter{
rps: rps,
burst: float64(burst),
@@ -197,10 +244,7 @@ func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ip, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
ip = r.RemoteAddr
}
ip := ClientIP(r, trustedProxy)
if !limiter.allow(ip) {
w.Header().Set("Retry-After", "60")

View File

@@ -6,6 +6,7 @@ import (
"crypto/ed25519"
"crypto/rand"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"testing"
@@ -271,7 +272,7 @@ func TestRequireRoleNoClaims(t *testing.T) {
}
func TestRateLimitAllows(t *testing.T) {
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := RateLimit(10, 5, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
@@ -289,7 +290,7 @@ func TestRateLimitAllows(t *testing.T) {
}
func TestRateLimitBlocks(t *testing.T) {
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
handler := RateLimit(0.1, 2, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
@@ -340,3 +341,124 @@ func TestExtractBearerToken(t *testing.T) {
})
}
}
// TestClientIP verifies the proxy-aware IP extraction logic.
func TestClientIP(t *testing.T) {
proxy := net.ParseIP("10.0.0.1")
tests := []struct {
name string
remoteAddr string
xForwardedFor string
xRealIP string
trustedProxy net.IP
want string
}{
{
name: "no proxy configured: uses RemoteAddr",
remoteAddr: "203.0.113.5:54321",
want: "203.0.113.5",
},
{
name: "proxy configured but request not from proxy: uses RemoteAddr",
remoteAddr: "198.51.100.9:12345",
xForwardedFor: "203.0.113.99",
trustedProxy: proxy,
want: "198.51.100.9",
},
{
name: "request from trusted proxy with X-Real-IP: uses X-Real-IP",
remoteAddr: "10.0.0.1:8080",
xRealIP: "203.0.113.42",
trustedProxy: proxy,
want: "203.0.113.42",
},
{
name: "request from trusted proxy with X-Forwarded-For: uses first entry",
remoteAddr: "10.0.0.1:8080",
xForwardedFor: "203.0.113.77, 10.0.0.2",
trustedProxy: proxy,
want: "203.0.113.77",
},
{
name: "X-Real-IP takes precedence over X-Forwarded-For",
remoteAddr: "10.0.0.1:8080",
xRealIP: "203.0.113.11",
xForwardedFor: "203.0.113.22",
trustedProxy: proxy,
want: "203.0.113.11",
},
{
name: "proxy request with invalid X-Real-IP falls back to X-Forwarded-For",
remoteAddr: "10.0.0.1:8080",
xRealIP: "not-an-ip",
xForwardedFor: "203.0.113.55",
trustedProxy: proxy,
want: "203.0.113.55",
},
{
name: "proxy request with no forwarding headers falls back to RemoteAddr host",
remoteAddr: "10.0.0.1:8080",
trustedProxy: proxy,
want: "10.0.0.1",
},
{
// Security: attacker fakes X-Forwarded-For but connects directly.
name: "spoofed X-Forwarded-For from non-proxy IP is ignored",
remoteAddr: "198.51.100.99:9999",
xForwardedFor: "127.0.0.1",
trustedProxy: proxy,
want: "198.51.100.99",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.RemoteAddr = tc.remoteAddr
if tc.xForwardedFor != "" {
req.Header.Set("X-Forwarded-For", tc.xForwardedFor)
}
if tc.xRealIP != "" {
req.Header.Set("X-Real-IP", tc.xRealIP)
}
got := ClientIP(req, tc.trustedProxy)
if got != tc.want {
t.Errorf("ClientIP = %q, want %q", got, tc.want)
}
})
}
}
// TestRateLimitTrustedProxy verifies that rate limiting uses the forwarded IP
// when the request originates from a trusted proxy.
func TestRateLimitTrustedProxy(t *testing.T) {
proxy := net.ParseIP("10.0.0.1")
// Very low rps and burst=1 so any two requests from the same IP are blocked.
handler := RateLimit(0.001, 1, proxy)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Two requests from the same real client IP, forwarded by the proxy.
// Both carry the same X-Real-IP; the second should be rate-limited.
for i, wantStatus := range []int{http.StatusOK, http.StatusTooManyRequests} {
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
req.RemoteAddr = "10.0.0.1:5000" // from the trusted proxy
req.Header.Set("X-Real-IP", "203.0.113.5")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != wantStatus {
t.Errorf("request %d: status = %d, want %d", i+1, rr.Code, wantStatus)
}
}
// A different real client (different X-Real-IP) should still be allowed.
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
req.RemoteAddr = "10.0.0.1:5001"
req.Header.Set("X-Real-IP", "203.0.113.99")
rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("distinct client: status = %d, want 200 (separate bucket)", rr.Code)
}
}

View File

@@ -2,7 +2,10 @@
// These are pure data definitions with no external dependencies.
package model
import "time"
import (
"fmt"
"time"
)
// AccountType distinguishes human interactive accounts from non-interactive
// service accounts.
@@ -43,6 +46,33 @@ type Account struct {
TOTPRequired bool `json:"totp_required"`
}
// Allowlisted role names (DEF-10).
// Only these strings may be stored in account_roles. Extending the set of
// valid roles requires a code change, ensuring that typos such as "admim"
// are caught at grant time rather than silently creating a useless role.
const (
RoleAdmin = "admin"
RoleUser = "user"
)
// allowedRoles is the compile-time set of recognised role names.
var allowedRoles = map[string]struct{}{
RoleAdmin: {},
RoleUser: {},
}
// ValidateRole returns nil if role is an allowlisted role name, or an error
// describing the problem. Call this before writing to account_roles.
//
// Security (DEF-10): prevents admins from accidentally creating unmatchable
// roles (e.g. "admim") by enforcing a compile-time allowlist.
func ValidateRole(role string) error {
if _, ok := allowedRoles[role]; !ok {
return fmt.Errorf("model: unknown role %q; allowed roles: admin, user", role)
}
return nil
}
// Role is a string label assigned to an account to grant permissions.
type Role struct {
GrantedAt time.Time `json:"granted_at"`

View File

@@ -16,6 +16,7 @@ import (
"fmt"
"io/fs"
"log/slog"
"net"
"net/http"
"git.wntrmute.dev/kyle/mcias/internal/auth"
@@ -56,10 +57,19 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
func (s *Server) Handler() http.Handler {
mux := http.NewServeMux()
// Security (DEF-03): parse the optional trusted-proxy address once here
// so RateLimit and audit-log helpers use consistent IP extraction.
// net.ParseIP returns nil for an empty string, which disables proxy
// trust and falls back to r.RemoteAddr.
var trustedProxy net.IP
if s.cfg.Server.TrustedProxy != "" {
trustedProxy = net.ParseIP(s.cfg.Server.TrustedProxy)
}
// Security: per-IP rate limiting on public auth endpoints to prevent
// brute-force login attempts and token-validation abuse. Parameters match
// the gRPC rate limiter (10 req/s sustained, burst 10).
loginRateLimit := middleware.RateLimit(10, 10)
loginRateLimit := middleware.RateLimit(10, 10, trustedProxy)
// Public endpoints (no authentication required).
mux.HandleFunc("GET /v1/health", s.handleHealth)
@@ -82,16 +92,20 @@ func (s *Server) Handler() http.Handler {
if err != nil {
panic(fmt.Sprintf("server: read openapi.yaml: %v", err))
}
mux.HandleFunc("GET /docs", func(w http.ResponseWriter, _ *http.Request) {
// Security (DEF-09): apply defensive HTTP headers to the docs handlers.
// The Swagger UI page at /docs loads JavaScript from the same origin
// and renders untrusted content (API descriptions), so it benefits from
// CSP, X-Frame-Options, and the other headers applied to the UI sub-mux.
mux.Handle("GET /docs", docsSecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(docsHTML)
})
mux.HandleFunc("GET /docs/openapi.yaml", func(w http.ResponseWriter, _ *http.Request) {
})))
mux.Handle("GET /docs/openapi.yaml", docsSecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/yaml")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(specYAML)
})
})))
// Authenticated endpoints.
requireAuth := middleware.RequireAuth(s.pubKey, s.db, s.cfg.Tokens.Issuer)
@@ -251,13 +265,21 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
return
}
valid, err := auth.ValidateTOTP(secret, req.TOTPCode)
valid, totpCounter, err := auth.ValidateTOTP(secret, req.TOTPCode)
if err != nil || !valid {
s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`)
_ = s.db.RecordLoginFailure(acct.ID)
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
return
}
// Security (CRIT-01): reject replay of a code already used within
// its ±30-second validity window.
if err := s.db.CheckAndUpdateTOTPCounter(acct.ID, totpCounter); err != nil {
s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"totp_replay"}`)
_ = s.db.RecordLoginFailure(acct.ID)
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
return
}
}
// Login succeeded: clear any outstanding failure counter.
@@ -764,11 +786,18 @@ func (s *Server) handleTOTPConfirm(w http.ResponseWriter, r *http.Request) {
return
}
valid, err := auth.ValidateTOTP(secret, req.Code)
valid, totpCounter, err := auth.ValidateTOTP(secret, req.Code)
if err != nil || !valid {
middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized")
return
}
// Security (CRIT-01): record the counter even during enrollment
// confirmation so the same code cannot be replayed immediately after
// confirming.
if err := s.db.CheckAndUpdateTOTPCounter(acct.ID, totpCounter); err != nil {
middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized")
return
}
// Mark TOTP as confirmed and required.
if err := s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
@@ -1149,8 +1178,14 @@ func (s *Server) loadAccount(w http.ResponseWriter, r *http.Request) (*model.Acc
}
// writeAudit appends an audit log entry, logging errors but not failing the request.
// The logged IP honours the trusted-proxy setting so the real client address
// is recorded rather than the proxy's address (DEF-03).
func (s *Server) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) {
ip := r.RemoteAddr
var proxyIP net.IP
if s.cfg.Server.TrustedProxy != "" {
proxyIP = net.ParseIP(s.cfg.Server.TrustedProxy)
}
ip := middleware.ClientIP(r, proxyIP)
if err := s.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil {
s.logger.Error("write audit event", "error", err, "event_type", eventType)
}
@@ -1191,6 +1226,25 @@ func extractBearerFromRequest(r *http.Request) (string, error) {
return auth[len(prefix):], nil
}
// docsSecurityHeaders adds the same defensive HTTP headers as the UI sub-mux
// to the /docs and /docs/openapi.yaml endpoints.
//
// Security (DEF-09): without these headers the Swagger UI HTML page is
// served without CSP, X-Frame-Options, or HSTS, leaving it susceptible
// to clickjacking and MIME-type confusion in browsers.
func docsSecurityHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Set("Content-Security-Policy",
"default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self'")
h.Set("X-Content-Type-Options", "nosniff")
h.Set("X-Frame-Options", "DENY")
h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
h.Set("Referrer-Policy", "no-referrer")
next.ServeHTTP(w, r)
})
}
// encodeBase64URL encodes bytes as base64url without padding.
func encodeBase64URL(b []byte) string {
const table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"

View File

@@ -376,7 +376,7 @@ func TestSetAndGetRoles(t *testing.T) {
// Set roles.
rr := doRequest(t, handler, "PUT", "/v1/accounts/"+target.UUID+"/roles", map[string][]string{
"roles": {"reader", "writer"},
"roles": {"admin", "user"},
}, adminToken)
if rr.Code != http.StatusNoContent {
t.Errorf("set roles status = %d, want 204; body: %s", rr.Code, rr.Body.String())

View File

@@ -70,11 +70,16 @@ func IssueToken(key ed25519.PrivateKey, issuer, subject string, roles []string,
exp := now.Add(expiry)
jti := uuid.New().String()
// Security (DEF-04): set NotBefore = now so tokens are not valid before
// the instant of issuance. This is a defence-in-depth measure: without
// nbf, a clock-skewed client or intermediate could present a token
// before its intended validity window.
jc := jwtClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: issuer,
Subject: subject,
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
ExpiresAt: jwt.NewNumericDate(exp),
ID: jti,
},
@@ -127,6 +132,9 @@ func ValidateToken(key ed25519.PublicKey, tokenString, expectedIssuer string) (*
jwt.WithIssuedAt(),
jwt.WithIssuer(expectedIssuer),
jwt.WithExpirationRequired(),
// Security (DEF-04): nbf is validated automatically by the library
// when the claim is present; no explicit option is needed. If nbf is
// in the future the library returns ErrTokenNotValidYet.
)
if err != nil {
// Map library errors to our typed errors for consistent handling.

View File

@@ -149,7 +149,7 @@ func (u *UIServer) handleTOTPStep(w http.ResponseWriter, r *http.Request) {
u.render(w, "login", LoginData{Error: "internal error"})
return
}
valid, err := auth.ValidateTOTP(secret, totpCode)
valid, totpCounter, err := auth.ValidateTOTP(secret, totpCode)
if err != nil || !valid {
u.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`)
_ = u.db.RecordLoginFailure(acct.ID)
@@ -166,6 +166,23 @@ func (u *UIServer) handleTOTPStep(w http.ResponseWriter, r *http.Request) {
})
return
}
// Security (CRIT-01): reject replay of a code already used within its
// ±30-second validity window.
if err := u.db.CheckAndUpdateTOTPCounter(acct.ID, totpCounter); err != nil {
u.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"totp_replay"}`)
_ = u.db.RecordLoginFailure(acct.ID)
newNonce, nonceErr := u.issueTOTPNonce(acct.ID)
if nonceErr != nil {
u.render(w, "login", LoginData{Error: "internal error"})
return
}
u.render(w, "totp_step", LoginData{
Error: "invalid TOTP code",
Username: username,
Nonce: newNonce,
})
return
}
u.finishLogin(w, r, acct)
}
@@ -251,7 +268,7 @@ func (u *UIServer) handleLogout(w http.ResponseWriter, r *http.Request) {
// writeAudit is a fire-and-forget audit log helper for the UI package.
func (u *UIServer) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) {
ip := clientIP(r)
ip := u.clientIP(r)
if err := u.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil {
u.logger.Warn("write audit event", "type", eventType, "error", err)
}

View File

@@ -22,14 +22,15 @@ import (
"html/template"
"io/fs"
"log/slog"
"net"
"net/http"
"strings"
"sync"
"time"
"git.wntrmute.dev/kyle/mcias/internal/auth"
"git.wntrmute.dev/kyle/mcias/internal/config"
"git.wntrmute.dev/kyle/mcias/internal/db"
"git.wntrmute.dev/kyle/mcias/internal/middleware"
"git.wntrmute.dev/kyle/mcias/internal/model"
"git.wntrmute.dev/kyle/mcias/web"
)
@@ -223,7 +224,7 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
tmpls[name] = clone
}
return &UIServer{
srv := &UIServer{
db: database,
cfg: cfg,
pubKey: pub,
@@ -232,7 +233,33 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
logger: logger,
csrf: csrf,
tmpls: tmpls,
}, nil
}
// Security (DEF-02): launch a background goroutine to evict expired TOTP
// nonces from pendingLogins. consumeTOTPNonce deletes entries on use, but
// entries abandoned by users who never complete step 2 would otherwise
// accumulate indefinitely, enabling a memory-exhaustion attack.
go srv.cleanupPendingLogins()
return srv, nil
}
// cleanupPendingLogins periodically evicts expired entries from pendingLogins.
// It runs every 5 minutes, which is well within the 90-second nonce TTL, so
// stale entries are removed before they can accumulate to any significant size.
func (u *UIServer) cleanupPendingLogins() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
u.pendingLogins.Range(func(key, value any) bool {
pl, ok := value.(*pendingLogin)
if !ok || now.After(pl.expiresAt) {
u.pendingLogins.Delete(key)
}
return true
})
}
}
// Register attaches all UI routes to mux, wrapped with security headers.
@@ -259,9 +286,18 @@ func (u *UIServer) Register(mux *http.ServeMux) {
http.NotFound(w, r)
})
// Security (DEF-01, DEF-03): apply the same per-IP rate limit as the REST
// /v1/auth/login endpoint, using the same proxy-aware IP extraction so
// the rate limit is applied to real client IPs behind a reverse proxy.
var trustedProxy net.IP
if u.cfg.Server.TrustedProxy != "" {
trustedProxy = net.ParseIP(u.cfg.Server.TrustedProxy)
}
loginRateLimit := middleware.RateLimit(10, 10, trustedProxy)
// Auth routes (no session required).
uiMux.HandleFunc("GET /login", u.handleLoginPage)
uiMux.HandleFunc("POST /login", u.handleLoginPost)
uiMux.Handle("POST /login", loginRateLimit(http.HandlerFunc(u.handleLoginPost)))
uiMux.HandleFunc("POST /logout", u.handleLogout)
// Protected routes.
@@ -498,13 +534,15 @@ func securityHeaders(next http.Handler) http.Handler {
})
}
// clientIP extracts the client IP from RemoteAddr (best effort).
func clientIP(r *http.Request) string {
addr := r.RemoteAddr
if idx := strings.LastIndex(addr, ":"); idx != -1 {
return addr[:idx]
// clientIP returns the real client IP for the request, respecting the
// server's trusted-proxy setting (DEF-03). Delegates to middleware.ClientIP
// so the same extraction logic is used for rate limiting and audit logging.
func (u *UIServer) clientIP(r *http.Request) string {
var proxyIP net.IP
if u.cfg.Server.TrustedProxy != "" {
proxyIP = net.ParseIP(u.cfg.Server.TrustedProxy)
}
return addr
return middleware.ClientIP(r, proxyIP)
}
// actorName resolves the username of the currently authenticated user from the