trusted proxy, TOTP replay protection, new tests
- Trusted proxy config option for proxy-aware IP extraction used by rate limiting and audit logs; validates proxy IP before trusting X-Forwarded-For / X-Real-IP headers - TOTP replay protection via counter-based validation to reject reused codes within the same time step (±30s) - RateLimit middleware updated to extract client IP from proxy headers without IP spoofing risk - New tests for ClientIP proxy logic (spoofed headers, fallback) and extended rate-limit proxy coverage - HTMX error banner script integrated into web UI base - .gitignore updated for mciasdb build artifact Security: resolves CRIT-01 (TOTP replay attack) and DEF-03 (proxy-unaware rate limiting); gRPC TOTP enrollment aligned with REST via StorePendingTOTP Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -200,19 +200,31 @@ func parsePHC(phc string) (ArgonParams, []byte, []byte, error) {
|
||||
// ValidateTOTP checks a 6-digit TOTP code against a raw TOTP secret (bytes).
|
||||
// A ±1 time-step window (±30s) is allowed to accommodate clock skew.
|
||||
//
|
||||
// Returns (true, counter, nil) on a valid code where counter is the HOTP
|
||||
// counter value that matched. The caller MUST pass this counter to
|
||||
// db.CheckAndUpdateTOTPCounter to prevent replay attacks within the validity
|
||||
// window (CRIT-01).
|
||||
//
|
||||
// Security:
|
||||
// - Comparison uses crypto/subtle.ConstantTimeCompare to resist timing attacks.
|
||||
// - Only RFC 6238-compliant HOTP (HMAC-SHA1) is implemented; no custom crypto.
|
||||
// - A ±1 window is the RFC 6238 recommendation; wider windows increase
|
||||
// exposure to code interception between generation and submission.
|
||||
func ValidateTOTP(secret []byte, code string) (bool, error) {
|
||||
// - The returned counter enables replay prevention: callers store it and
|
||||
// reject any future code that does not advance past it (RFC 6238 §5.2).
|
||||
func ValidateTOTP(secret []byte, code string) (bool, int64, error) {
|
||||
if len(code) != 6 {
|
||||
return false, nil
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
step := int64(30) // RFC 6238 default time step in seconds
|
||||
|
||||
// Security: evaluate all three counters with constant-time comparisons
|
||||
// before returning. Early-exit would leak which counter matched via
|
||||
// timing; we instead record the match and continue, returning at the end.
|
||||
var matched bool
|
||||
var matchedCounter int64
|
||||
for _, counter := range []int64{
|
||||
now/step - 1,
|
||||
now / step,
|
||||
@@ -220,14 +232,21 @@ func ValidateTOTP(secret []byte, code string) (bool, error) {
|
||||
} {
|
||||
expected, err := hotp(secret, uint64(counter)) //nolint:gosec // G115: counter is Unix time / step, always non-negative
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("auth: compute TOTP: %w", err)
|
||||
return false, 0, fmt.Errorf("auth: compute TOTP: %w", err)
|
||||
}
|
||||
// Security: constant-time comparison to prevent timing attack.
|
||||
// We deliberately do NOT break early so that all three comparisons
|
||||
// always execute, preventing a timing side-channel on which counter
|
||||
// slot matched.
|
||||
if subtle.ConstantTimeCompare([]byte(code), []byte(expected)) == 1 {
|
||||
return true, nil
|
||||
matched = true
|
||||
matchedCounter = counter
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
if matched {
|
||||
return true, matchedCounter, nil
|
||||
}
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
// hotp computes an HMAC-SHA1-based OTP for a given counter value.
|
||||
|
||||
@@ -101,13 +101,16 @@ func TestValidateTOTP(t *testing.T) {
|
||||
t.Fatalf("hotp: %v", err)
|
||||
}
|
||||
|
||||
ok, err := ValidateTOTP(rawSecret, code)
|
||||
ok, counter, err := ValidateTOTP(rawSecret, code)
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateTOTP: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Errorf("ValidateTOTP rejected a valid code %q", code)
|
||||
}
|
||||
if ok && counter == 0 {
|
||||
t.Errorf("ValidateTOTP returned zero counter for valid code")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTOTPWrongCode verifies that an incorrect code is rejected.
|
||||
@@ -117,7 +120,7 @@ func TestValidateTOTPWrongCode(t *testing.T) {
|
||||
t.Fatalf("GenerateTOTPSecret: %v", err)
|
||||
}
|
||||
|
||||
ok, err := ValidateTOTP(rawSecret, "000000")
|
||||
ok, _, err := ValidateTOTP(rawSecret, "000000")
|
||||
if err != nil {
|
||||
t.Fatalf("ValidateTOTP: %v", err)
|
||||
}
|
||||
@@ -135,7 +138,7 @@ func TestValidateTOTPWrongLength(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, code := range []string{"", "12345", "1234567", "abcdef"} {
|
||||
ok, err := ValidateTOTP(rawSecret, code)
|
||||
ok, _, err := ValidateTOTP(rawSecret, code)
|
||||
if err != nil {
|
||||
t.Errorf("ValidateTOTP(%q): unexpected error: %v", code, err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ package config
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
@@ -30,6 +31,17 @@ type ServerConfig struct {
|
||||
GRPCAddr string `toml:"grpc_addr"`
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
// TrustedProxy is the IP address (not a range) of a reverse proxy that
|
||||
// sits in front of the server and sets X-Forwarded-For or X-Real-IP
|
||||
// headers. When set, the rate limiter and audit log extract the real
|
||||
// client IP from these headers instead of r.RemoteAddr.
|
||||
//
|
||||
// Security: only requests whose r.RemoteAddr matches TrustedProxy are
|
||||
// trusted to carry a valid forwarded-IP header. All other requests use
|
||||
// r.RemoteAddr directly, so this field cannot be exploited for IP
|
||||
// spoofing by external clients. Omit or leave empty when running
|
||||
// without a reverse proxy.
|
||||
TrustedProxy string `toml:"trusted_proxy"`
|
||||
}
|
||||
|
||||
// DatabaseConfig holds SQLite database settings.
|
||||
@@ -137,6 +149,14 @@ func (c *Config) validate() error {
|
||||
if c.Server.TLSKey == "" {
|
||||
errs = append(errs, errors.New("server.tls_key is required"))
|
||||
}
|
||||
// Security (DEF-03): if trusted_proxy is set it must be a valid IP address
|
||||
// (not a hostname or CIDR) so the middleware can compare it to the parsed
|
||||
// host part of r.RemoteAddr using a reliable byte-level equality check.
|
||||
if c.Server.TrustedProxy != "" {
|
||||
if net.ParseIP(c.Server.TrustedProxy) == nil {
|
||||
errs = append(errs, fmt.Errorf("server.trusted_proxy %q is not a valid IP address", c.Server.TrustedProxy))
|
||||
}
|
||||
}
|
||||
|
||||
// Database
|
||||
if c.Database.Path == "" {
|
||||
@@ -147,14 +167,31 @@ func (c *Config) validate() error {
|
||||
if c.Tokens.Issuer == "" {
|
||||
errs = append(errs, errors.New("tokens.issuer is required"))
|
||||
}
|
||||
// Security (DEF-05): enforce both lower and upper bounds on token expiry
|
||||
// durations. An operator misconfiguration could otherwise produce tokens
|
||||
// valid for centuries, which would be irrevocable (bar explicit JTI
|
||||
// revocation) if a token were stolen. Upper bounds are intentionally
|
||||
// generous to accommodate a range of legitimate deployments while
|
||||
// catching obvious typos (e.g. "876000h" instead of "8760h").
|
||||
const (
|
||||
maxDefaultExpiry = 30 * 24 * time.Hour // 30 days
|
||||
maxAdminExpiry = 24 * time.Hour // 24 hours
|
||||
maxServiceExpiry = 5 * 365 * 24 * time.Hour // 5 years
|
||||
)
|
||||
if c.Tokens.DefaultExpiry.Duration <= 0 {
|
||||
errs = append(errs, errors.New("tokens.default_expiry must be positive"))
|
||||
} else if c.Tokens.DefaultExpiry.Duration > maxDefaultExpiry {
|
||||
errs = append(errs, fmt.Errorf("tokens.default_expiry must be <= %s (got %s)", maxDefaultExpiry, c.Tokens.DefaultExpiry.Duration))
|
||||
}
|
||||
if c.Tokens.AdminExpiry.Duration <= 0 {
|
||||
errs = append(errs, errors.New("tokens.admin_expiry must be positive"))
|
||||
} else if c.Tokens.AdminExpiry.Duration > maxAdminExpiry {
|
||||
errs = append(errs, fmt.Errorf("tokens.admin_expiry must be <= %s (got %s)", maxAdminExpiry, c.Tokens.AdminExpiry.Duration))
|
||||
}
|
||||
if c.Tokens.ServiceExpiry.Duration <= 0 {
|
||||
errs = append(errs, errors.New("tokens.service_expiry must be positive"))
|
||||
} else if c.Tokens.ServiceExpiry.Duration > maxServiceExpiry {
|
||||
errs = append(errs, fmt.Errorf("tokens.service_expiry must be <= %s (got %s)", maxServiceExpiry, c.Tokens.ServiceExpiry.Duration))
|
||||
}
|
||||
|
||||
// Argon2 — enforce OWASP 2023 minimums (time=2, memory=65536 KiB).
|
||||
|
||||
@@ -210,6 +210,40 @@ threads = 4
|
||||
}
|
||||
}
|
||||
|
||||
// TestTrustedProxyValidation verifies that trusted_proxy must be a valid IP.
|
||||
func TestTrustedProxyValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
proxy string
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty is valid (disabled)", "", false},
|
||||
{"valid IPv4", "127.0.0.1", false},
|
||||
{"valid IPv6 loopback", "::1", false},
|
||||
{"valid private IPv4", "10.0.0.1", false},
|
||||
{"hostname rejected", "proxy.example.com", true},
|
||||
{"CIDR rejected", "10.0.0.0/8", true},
|
||||
{"garbage rejected", "not-an-ip", true},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg, _ := Load(writeTempConfig(t, validConfig()))
|
||||
if cfg == nil {
|
||||
t.Fatal("baseline config load failed")
|
||||
}
|
||||
cfg.Server.TrustedProxy = tc.proxy
|
||||
err := cfg.validate()
|
||||
if tc.wantErr && err == nil {
|
||||
t.Errorf("expected validation error for proxy=%q, got nil", tc.proxy)
|
||||
}
|
||||
if !tc.wantErr && err != nil {
|
||||
t.Errorf("unexpected error for proxy=%q: %v", tc.proxy, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDurationParsing(t *testing.T) {
|
||||
var d duration
|
||||
if err := d.UnmarshalText([]byte("1h30m")); err != nil {
|
||||
|
||||
@@ -70,7 +70,10 @@ func (db *DB) GetAccountByID(id int64) (*model.Account, error) {
|
||||
`, id))
|
||||
}
|
||||
|
||||
// GetAccountByUsername retrieves an account by username (case-insensitive).
|
||||
// GetAccountByUsername retrieves an account by username.
|
||||
// Matching is case-sensitive: SQLite uses BINARY collation by default, so
|
||||
// "admin" and "Admin" are distinct usernames. This is intentional for an
|
||||
// SSO system where usernames should be treated as opaque identifiers.
|
||||
// Returns ErrNotFound if no matching account exists.
|
||||
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
|
||||
return db.scanAccount(db.sql.QueryRow(`
|
||||
@@ -184,6 +187,46 @@ func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckAndUpdateTOTPCounter atomically verifies that counter is strictly
|
||||
// greater than the last accepted TOTP counter for the account, and if so,
|
||||
// stores counter as the new last accepted value.
|
||||
//
|
||||
// Returns ErrTOTPReplay if counter ≤ the stored value, preventing a replay
|
||||
// of a previously accepted code within the ±1 time-step validity window.
|
||||
// On the first successful TOTP login (stored value NULL) any counter is
|
||||
// accepted.
|
||||
//
|
||||
// Security (CRIT-01): RFC 6238 §5.2 recommends recording the last OTP
|
||||
// counter used and rejecting any code that does not advance it. Without
|
||||
// this, an intercepted code remains valid for up to 90 seconds. The update
|
||||
// is performed in a single parameterized SQL statement, so there is no
|
||||
// TOCTOU window between the check and the write.
|
||||
func (db *DB) CheckAndUpdateTOTPCounter(accountID int64, counter int64) error {
|
||||
result, err := db.sql.Exec(`
|
||||
UPDATE accounts
|
||||
SET last_totp_counter = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
AND (last_totp_counter IS NULL OR last_totp_counter < ?)
|
||||
`, counter, now(), accountID, counter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: check-and-update TOTP counter: %w", err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: check-and-update TOTP counter rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
// Security: the counter was not advanced — this code has already been
|
||||
// used within its validity window. Treat as authentication failure.
|
||||
return ErrTOTPReplay
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ErrTOTPReplay is returned by CheckAndUpdateTOTPCounter when the submitted
|
||||
// TOTP code corresponds to a counter value that has already been accepted.
|
||||
var ErrTOTPReplay = errors.New("db: TOTP code already used (replay)")
|
||||
|
||||
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
|
||||
func (db *DB) ClearTOTP(accountID int64) error {
|
||||
_, err := db.sql.Exec(`
|
||||
@@ -300,6 +343,12 @@ func (db *DB) GetRoles(accountID int64) ([]string, error) {
|
||||
|
||||
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
|
||||
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
|
||||
// Security (DEF-10): reject unknown roles before writing to the DB so
|
||||
// that typos (e.g. "admim") are caught immediately rather than silently
|
||||
// creating an unmatchable role.
|
||||
if err := model.ValidateRole(role); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
@@ -323,6 +372,14 @@ func (db *DB) RevokeRole(accountID int64, role string) error {
|
||||
|
||||
// SetRoles replaces the full role set for an account atomically.
|
||||
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
|
||||
// Security (DEF-10): validate all roles before opening the transaction so
|
||||
// we fail fast without touching the database on an invalid input.
|
||||
for _, role := range roles {
|
||||
if err := model.ValidateRole(role); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
tx, err := db.sql.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set roles begin tx: %w", err)
|
||||
|
||||
@@ -65,7 +65,14 @@ func (db *DB) configure() error {
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA foreign_keys=ON",
|
||||
"PRAGMA busy_timeout=5000",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
// Security (DEF-07): FULL synchronous mode ensures every write is
|
||||
// flushed to disk before SQLite considers it committed. With WAL
|
||||
// mode + NORMAL, a power failure between a write and the next
|
||||
// checkpoint could lose the most recent committed transactions,
|
||||
// including token issuance and revocation records — which must be
|
||||
// durable. The performance cost is negligible for a single-node
|
||||
// personal SSO server.
|
||||
"PRAGMA synchronous=FULL",
|
||||
}
|
||||
for _, p := range pragmas {
|
||||
if _, err := db.sql.Exec(p); err != nil {
|
||||
|
||||
@@ -162,7 +162,7 @@ func TestRoleOperations(t *testing.T) {
|
||||
}
|
||||
|
||||
// SetRoles
|
||||
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
|
||||
if err := db.SetRoles(acct.ID, []string{"admin", "user"}, nil); err != nil {
|
||||
t.Fatalf("SetRoles: %v", err)
|
||||
}
|
||||
roles, err = db.GetRoles(acct.ID)
|
||||
|
||||
@@ -22,7 +22,7 @@ var migrationsFS embed.FS
|
||||
// LatestSchemaVersion is the highest migration version defined in the
|
||||
// migrations/ directory. Update this constant whenever a new migration file
|
||||
// is added.
|
||||
const LatestSchemaVersion = 6
|
||||
const LatestSchemaVersion = 7
|
||||
|
||||
// newMigrate constructs a migrate.Migrate instance backed by the embedded SQL
|
||||
// files. It opens a dedicated *sql.DB using the same DSN as the main
|
||||
|
||||
9
internal/db/migrations/000007_totp_counter.up.sql
Normal file
9
internal/db/migrations/000007_totp_counter.up.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
-- Add last_totp_counter to track the most recently accepted TOTP counter value
|
||||
-- per account. This is used to prevent TOTP replay attacks within the ±1
|
||||
-- time-step validity window. NULL means no TOTP code has ever been accepted
|
||||
-- for this account (fresh enrollment or TOTP not yet used).
|
||||
--
|
||||
-- Security (CRIT-01): RFC 6238 §5.2 recommends recording the last OTP counter
|
||||
-- used and rejecting codes that do not advance it, eliminating the ~90-second
|
||||
-- replay window that would otherwise be exploitable.
|
||||
ALTER TABLE accounts ADD COLUMN last_totp_counter INTEGER DEFAULT NULL;
|
||||
@@ -72,8 +72,14 @@ func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest
|
||||
|
||||
if acct.TOTPRequired {
|
||||
if req.TotpCode == "" {
|
||||
// Security (DEF-08): password was already verified, so a missing
|
||||
// TOTP code means the gRPC client needs to re-prompt the user —
|
||||
// it is not a credential failure. Do NOT increment the lockout
|
||||
// counter here; doing so would lock out well-behaved clients that
|
||||
// call Login in two steps (password first, TOTP second) and would
|
||||
// also let an attacker trigger account lockout by omitting the
|
||||
// code after a successful password guess.
|
||||
a.s.db.WriteAuditEvent(model.EventLoginFail, &acct.ID, nil, ip, `{"reason":"totp_missing"}`) //nolint:errcheck
|
||||
_ = a.s.db.RecordLoginFailure(acct.ID)
|
||||
return nil, status.Error(codes.Unauthenticated, "TOTP code required")
|
||||
}
|
||||
secret, err := crypto.OpenAESGCM(a.s.masterKey, acct.TOTPSecretNonce, acct.TOTPSecretEnc)
|
||||
@@ -81,12 +87,19 @@ func (a *authServiceServer) Login(ctx context.Context, req *mciasv1.LoginRequest
|
||||
a.s.logger.Error("decrypt TOTP secret", "error", err, "account_id", acct.ID)
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
valid, err := auth.ValidateTOTP(secret, req.TotpCode)
|
||||
valid, counter, err := auth.ValidateTOTP(secret, req.TotpCode)
|
||||
if err != nil || !valid {
|
||||
a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"wrong_totp"}`) //nolint:errcheck
|
||||
_ = a.s.db.RecordLoginFailure(acct.ID)
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
// Security (CRIT-01): reject replay of a code already used within
|
||||
// its ±30-second validity window.
|
||||
if err := a.s.db.CheckAndUpdateTOTPCounter(acct.ID, counter); err != nil {
|
||||
a.s.db.WriteAuditEvent(model.EventLoginTOTPFail, &acct.ID, nil, ip, `{"reason":"totp_replay"}`) //nolint:errcheck
|
||||
_ = a.s.db.RecordLoginFailure(acct.ID)
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
}
|
||||
|
||||
// Login succeeded: clear any outstanding failure counter.
|
||||
@@ -199,7 +212,12 @@ func (a *authServiceServer) EnrollTOTP(ctx context.Context, _ *mciasv1.EnrollTOT
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
if err := a.s.db.SetTOTP(acct.ID, secretEnc, secretNonce); err != nil {
|
||||
// Security: use StorePendingTOTP (not SetTOTP) so that totp_required is
|
||||
// not set to 1 until the user confirms the code via ConfirmTOTP. Calling
|
||||
// SetTOTP here would immediately lock the account behind TOTP before the
|
||||
// user has had a chance to configure their authenticator app — matching the
|
||||
// behaviour of the REST EnrollTOTP handler at internal/server/server.go.
|
||||
if err := a.s.db.StorePendingTOTP(acct.ID, secretEnc, secretNonce); err != nil {
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
@@ -232,10 +250,15 @@ func (a *authServiceServer) ConfirmTOTP(ctx context.Context, req *mciasv1.Confir
|
||||
return nil, status.Error(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
valid, err := auth.ValidateTOTP(secret, req.Code)
|
||||
valid, counter, err := auth.ValidateTOTP(secret, req.Code)
|
||||
if err != nil || !valid {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
|
||||
}
|
||||
// Security (CRIT-01): record the counter even during enrollment confirmation
|
||||
// so the same code cannot be replayed immediately after confirming.
|
||||
if err := a.s.db.CheckAndUpdateTOTPCounter(acct.ID, counter); err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid TOTP code")
|
||||
}
|
||||
|
||||
// SetTOTP with existing enc/nonce sets totp_required=1, confirming enrollment.
|
||||
if err := a.s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
|
||||
|
||||
@@ -542,7 +542,7 @@ func TestSetAndGetRoles(t *testing.T) {
|
||||
|
||||
_, err = cl.SetRoles(authCtx(adminTok), &mciasv1.SetRolesRequest{
|
||||
Id: id,
|
||||
Roles: []string{"editor", "viewer"},
|
||||
Roles: []string{"admin", "user"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("SetRoles: %v", err)
|
||||
|
||||
@@ -176,15 +176,62 @@ type ipRateLimiter struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// ClientIP returns the real client IP for a request, optionally trusting a
|
||||
// single reverse-proxy address.
|
||||
//
|
||||
// Security (DEF-03): X-Forwarded-For and X-Real-IP headers can be forged by
|
||||
// any client. This function only honours them when the immediate TCP peer
|
||||
// (r.RemoteAddr) matches trustedProxy exactly. When trustedProxy is nil or
|
||||
// the peer address does not match, r.RemoteAddr is used unconditionally.
|
||||
//
|
||||
// This prevents IP-spoofing attacks: an attacker who sends a fake
|
||||
// X-Forwarded-For header from their own connection still has their real IP
|
||||
// used for rate limiting, because their RemoteAddr will not match the proxy.
|
||||
//
|
||||
// Only the first (leftmost) value in X-Forwarded-For is used, as that is the
|
||||
// client-supplied address as appended by the outermost proxy. If neither
|
||||
// header is present, RemoteAddr is used as a fallback even when the request
|
||||
// comes from the proxy.
|
||||
func ClientIP(r *http.Request, trustedProxy net.IP) string {
|
||||
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
remoteHost = r.RemoteAddr
|
||||
}
|
||||
|
||||
if trustedProxy != nil {
|
||||
remoteIP := net.ParseIP(remoteHost)
|
||||
if remoteIP != nil && remoteIP.Equal(trustedProxy) {
|
||||
// Request is from the trusted proxy; extract the real client IP.
|
||||
// Prefer X-Real-IP (single value) over X-Forwarded-For (may be a
|
||||
// comma-separated list when multiple proxies are chained).
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// Take the first (leftmost) address — the original client.
|
||||
first, _, _ := strings.Cut(xff, ",")
|
||||
if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
|
||||
return ip.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return remoteHost
|
||||
}
|
||||
|
||||
// RateLimit returns middleware implementing a per-IP token bucket.
|
||||
// rps is the sustained request rate (tokens refilled per second).
|
||||
// burst is the maximum burst size (initial and maximum token count).
|
||||
// trustedProxy, if non-nil, enables proxy-aware client IP extraction via
|
||||
// ClientIP; pass nil when not running behind a reverse proxy.
|
||||
//
|
||||
// Security: Rate limiting is applied at the IP level. In production, the
|
||||
// server should be behind a reverse proxy that sets X-Forwarded-For; this
|
||||
// middleware uses RemoteAddr directly which may be the proxy IP. For single-
|
||||
// instance deployment without a proxy, RemoteAddr is the client IP.
|
||||
func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
|
||||
// Security (DEF-03): when trustedProxy is set, real client IPs are extracted
|
||||
// from X-Forwarded-For/X-Real-IP headers but only for requests whose
|
||||
// RemoteAddr matches the trusted proxy, preventing IP-spoofing.
|
||||
func RateLimit(rps float64, burst int, trustedProxy net.IP) func(http.Handler) http.Handler {
|
||||
limiter := &ipRateLimiter{
|
||||
rps: rps,
|
||||
burst: float64(burst),
|
||||
@@ -197,10 +244,7 @@ func RateLimit(rps float64, burst int) func(http.Handler) http.Handler {
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ip, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
ip := ClientIP(r, trustedProxy)
|
||||
|
||||
if !limiter.allow(ip) {
|
||||
w.Header().Set("Retry-After", "60")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -271,7 +272,7 @@ func TestRequireRoleNoClaims(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimitAllows(t *testing.T) {
|
||||
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := RateLimit(10, 5, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
@@ -289,7 +290,7 @@ func TestRateLimitAllows(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRateLimitBlocks(t *testing.T) {
|
||||
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
handler := RateLimit(0.1, 2, nil)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
@@ -340,3 +341,124 @@ func TestExtractBearerToken(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientIP verifies the proxy-aware IP extraction logic.
|
||||
func TestClientIP(t *testing.T) {
|
||||
proxy := net.ParseIP("10.0.0.1")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xForwardedFor string
|
||||
xRealIP string
|
||||
trustedProxy net.IP
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no proxy configured: uses RemoteAddr",
|
||||
remoteAddr: "203.0.113.5:54321",
|
||||
want: "203.0.113.5",
|
||||
},
|
||||
{
|
||||
name: "proxy configured but request not from proxy: uses RemoteAddr",
|
||||
remoteAddr: "198.51.100.9:12345",
|
||||
xForwardedFor: "203.0.113.99",
|
||||
trustedProxy: proxy,
|
||||
want: "198.51.100.9",
|
||||
},
|
||||
{
|
||||
name: "request from trusted proxy with X-Real-IP: uses X-Real-IP",
|
||||
remoteAddr: "10.0.0.1:8080",
|
||||
xRealIP: "203.0.113.42",
|
||||
trustedProxy: proxy,
|
||||
want: "203.0.113.42",
|
||||
},
|
||||
{
|
||||
name: "request from trusted proxy with X-Forwarded-For: uses first entry",
|
||||
remoteAddr: "10.0.0.1:8080",
|
||||
xForwardedFor: "203.0.113.77, 10.0.0.2",
|
||||
trustedProxy: proxy,
|
||||
want: "203.0.113.77",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP takes precedence over X-Forwarded-For",
|
||||
remoteAddr: "10.0.0.1:8080",
|
||||
xRealIP: "203.0.113.11",
|
||||
xForwardedFor: "203.0.113.22",
|
||||
trustedProxy: proxy,
|
||||
want: "203.0.113.11",
|
||||
},
|
||||
{
|
||||
name: "proxy request with invalid X-Real-IP falls back to X-Forwarded-For",
|
||||
remoteAddr: "10.0.0.1:8080",
|
||||
xRealIP: "not-an-ip",
|
||||
xForwardedFor: "203.0.113.55",
|
||||
trustedProxy: proxy,
|
||||
want: "203.0.113.55",
|
||||
},
|
||||
{
|
||||
name: "proxy request with no forwarding headers falls back to RemoteAddr host",
|
||||
remoteAddr: "10.0.0.1:8080",
|
||||
trustedProxy: proxy,
|
||||
want: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
// Security: attacker fakes X-Forwarded-For but connects directly.
|
||||
name: "spoofed X-Forwarded-For from non-proxy IP is ignored",
|
||||
remoteAddr: "198.51.100.99:9999",
|
||||
xForwardedFor: "127.0.0.1",
|
||||
trustedProxy: proxy,
|
||||
want: "198.51.100.99",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.RemoteAddr = tc.remoteAddr
|
||||
if tc.xForwardedFor != "" {
|
||||
req.Header.Set("X-Forwarded-For", tc.xForwardedFor)
|
||||
}
|
||||
if tc.xRealIP != "" {
|
||||
req.Header.Set("X-Real-IP", tc.xRealIP)
|
||||
}
|
||||
got := ClientIP(req, tc.trustedProxy)
|
||||
if got != tc.want {
|
||||
t.Errorf("ClientIP = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimitTrustedProxy verifies that rate limiting uses the forwarded IP
|
||||
// when the request originates from a trusted proxy.
|
||||
func TestRateLimitTrustedProxy(t *testing.T) {
|
||||
proxy := net.ParseIP("10.0.0.1")
|
||||
// Very low rps and burst=1 so any two requests from the same IP are blocked.
|
||||
handler := RateLimit(0.001, 1, proxy)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Two requests from the same real client IP, forwarded by the proxy.
|
||||
// Both carry the same X-Real-IP; the second should be rate-limited.
|
||||
for i, wantStatus := range []int{http.StatusOK, http.StatusTooManyRequests} {
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:5000" // from the trusted proxy
|
||||
req.Header.Set("X-Real-IP", "203.0.113.5")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != wantStatus {
|
||||
t.Errorf("request %d: status = %d, want %d", i+1, rr.Code, wantStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// A different real client (different X-Real-IP) should still be allowed.
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
||||
req.RemoteAddr = "10.0.0.1:5001"
|
||||
req.Header.Set("X-Real-IP", "203.0.113.99")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("distinct client: status = %d, want 200 (separate bucket)", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
// These are pure data definitions with no external dependencies.
|
||||
package model
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccountType distinguishes human interactive accounts from non-interactive
|
||||
// service accounts.
|
||||
@@ -43,6 +46,33 @@ type Account struct {
|
||||
TOTPRequired bool `json:"totp_required"`
|
||||
}
|
||||
|
||||
// Allowlisted role names (DEF-10).
|
||||
// Only these strings may be stored in account_roles. Extending the set of
|
||||
// valid roles requires a code change, ensuring that typos such as "admim"
|
||||
// are caught at grant time rather than silently creating a useless role.
|
||||
const (
|
||||
RoleAdmin = "admin"
|
||||
RoleUser = "user"
|
||||
)
|
||||
|
||||
// allowedRoles is the compile-time set of recognised role names.
|
||||
var allowedRoles = map[string]struct{}{
|
||||
RoleAdmin: {},
|
||||
RoleUser: {},
|
||||
}
|
||||
|
||||
// ValidateRole returns nil if role is an allowlisted role name, or an error
|
||||
// describing the problem. Call this before writing to account_roles.
|
||||
//
|
||||
// Security (DEF-10): prevents admins from accidentally creating unmatchable
|
||||
// roles (e.g. "admim") by enforcing a compile-time allowlist.
|
||||
func ValidateRole(role string) error {
|
||||
if _, ok := allowedRoles[role]; !ok {
|
||||
return fmt.Errorf("model: unknown role %q; allowed roles: admin, user", role)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Role is a string label assigned to an account to grant permissions.
|
||||
type Role struct {
|
||||
GrantedAt time.Time `json:"granted_at"`
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||
@@ -56,10 +57,19 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
|
||||
func (s *Server) Handler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
// Security (DEF-03): parse the optional trusted-proxy address once here
|
||||
// so RateLimit and audit-log helpers use consistent IP extraction.
|
||||
// net.ParseIP returns nil for an empty string, which disables proxy
|
||||
// trust and falls back to r.RemoteAddr.
|
||||
var trustedProxy net.IP
|
||||
if s.cfg.Server.TrustedProxy != "" {
|
||||
trustedProxy = net.ParseIP(s.cfg.Server.TrustedProxy)
|
||||
}
|
||||
|
||||
// Security: per-IP rate limiting on public auth endpoints to prevent
|
||||
// brute-force login attempts and token-validation abuse. Parameters match
|
||||
// the gRPC rate limiter (10 req/s sustained, burst 10).
|
||||
loginRateLimit := middleware.RateLimit(10, 10)
|
||||
loginRateLimit := middleware.RateLimit(10, 10, trustedProxy)
|
||||
|
||||
// Public endpoints (no authentication required).
|
||||
mux.HandleFunc("GET /v1/health", s.handleHealth)
|
||||
@@ -82,16 +92,20 @@ func (s *Server) Handler() http.Handler {
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("server: read openapi.yaml: %v", err))
|
||||
}
|
||||
mux.HandleFunc("GET /docs", func(w http.ResponseWriter, _ *http.Request) {
|
||||
// Security (DEF-09): apply defensive HTTP headers to the docs handlers.
|
||||
// The Swagger UI page at /docs loads JavaScript from the same origin
|
||||
// and renders untrusted content (API descriptions), so it benefits from
|
||||
// CSP, X-Frame-Options, and the other headers applied to the UI sub-mux.
|
||||
mux.Handle("GET /docs", docsSecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(docsHTML)
|
||||
})
|
||||
mux.HandleFunc("GET /docs/openapi.yaml", func(w http.ResponseWriter, _ *http.Request) {
|
||||
})))
|
||||
mux.Handle("GET /docs/openapi.yaml", docsSecurityHeaders(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/yaml")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(specYAML)
|
||||
})
|
||||
})))
|
||||
|
||||
// Authenticated endpoints.
|
||||
requireAuth := middleware.RequireAuth(s.pubKey, s.db, s.cfg.Tokens.Issuer)
|
||||
@@ -251,13 +265,21 @@ func (s *Server) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
middleware.WriteError(w, http.StatusInternalServerError, "internal error", "internal_error")
|
||||
return
|
||||
}
|
||||
valid, err := auth.ValidateTOTP(secret, req.TOTPCode)
|
||||
valid, totpCounter, err := auth.ValidateTOTP(secret, req.TOTPCode)
|
||||
if err != nil || !valid {
|
||||
s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`)
|
||||
_ = s.db.RecordLoginFailure(acct.ID)
|
||||
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||
return
|
||||
}
|
||||
// Security (CRIT-01): reject replay of a code already used within
|
||||
// its ±30-second validity window.
|
||||
if err := s.db.CheckAndUpdateTOTPCounter(acct.ID, totpCounter); err != nil {
|
||||
s.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"totp_replay"}`)
|
||||
_ = s.db.RecordLoginFailure(acct.ID)
|
||||
middleware.WriteError(w, http.StatusUnauthorized, "invalid credentials", "unauthorized")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Login succeeded: clear any outstanding failure counter.
|
||||
@@ -764,11 +786,18 @@ func (s *Server) handleTOTPConfirm(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
valid, err := auth.ValidateTOTP(secret, req.Code)
|
||||
valid, totpCounter, err := auth.ValidateTOTP(secret, req.Code)
|
||||
if err != nil || !valid {
|
||||
middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized")
|
||||
return
|
||||
}
|
||||
// Security (CRIT-01): record the counter even during enrollment
|
||||
// confirmation so the same code cannot be replayed immediately after
|
||||
// confirming.
|
||||
if err := s.db.CheckAndUpdateTOTPCounter(acct.ID, totpCounter); err != nil {
|
||||
middleware.WriteError(w, http.StatusUnauthorized, "invalid TOTP code", "unauthorized")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark TOTP as confirmed and required.
|
||||
if err := s.db.SetTOTP(acct.ID, acct.TOTPSecretEnc, acct.TOTPSecretNonce); err != nil {
|
||||
@@ -1149,8 +1178,14 @@ func (s *Server) loadAccount(w http.ResponseWriter, r *http.Request) (*model.Acc
|
||||
}
|
||||
|
||||
// writeAudit appends an audit log entry, logging errors but not failing the request.
|
||||
// The logged IP honours the trusted-proxy setting so the real client address
|
||||
// is recorded rather than the proxy's address (DEF-03).
|
||||
func (s *Server) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) {
|
||||
ip := r.RemoteAddr
|
||||
var proxyIP net.IP
|
||||
if s.cfg.Server.TrustedProxy != "" {
|
||||
proxyIP = net.ParseIP(s.cfg.Server.TrustedProxy)
|
||||
}
|
||||
ip := middleware.ClientIP(r, proxyIP)
|
||||
if err := s.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil {
|
||||
s.logger.Error("write audit event", "error", err, "event_type", eventType)
|
||||
}
|
||||
@@ -1191,6 +1226,25 @@ func extractBearerFromRequest(r *http.Request) (string, error) {
|
||||
return auth[len(prefix):], nil
|
||||
}
|
||||
|
||||
// docsSecurityHeaders adds the same defensive HTTP headers as the UI sub-mux
|
||||
// to the /docs and /docs/openapi.yaml endpoints.
|
||||
//
|
||||
// Security (DEF-09): without these headers the Swagger UI HTML page is
|
||||
// served without CSP, X-Frame-Options, or HSTS, leaving it susceptible
|
||||
// to clickjacking and MIME-type confusion in browsers.
|
||||
func docsSecurityHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h.Set("Content-Security-Policy",
|
||||
"default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data:; font-src 'self'")
|
||||
h.Set("X-Content-Type-Options", "nosniff")
|
||||
h.Set("X-Frame-Options", "DENY")
|
||||
h.Set("Strict-Transport-Security", "max-age=63072000; includeSubDomains")
|
||||
h.Set("Referrer-Policy", "no-referrer")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// encodeBase64URL encodes bytes as base64url without padding.
|
||||
func encodeBase64URL(b []byte) string {
|
||||
const table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
|
||||
@@ -376,7 +376,7 @@ func TestSetAndGetRoles(t *testing.T) {
|
||||
|
||||
// Set roles.
|
||||
rr := doRequest(t, handler, "PUT", "/v1/accounts/"+target.UUID+"/roles", map[string][]string{
|
||||
"roles": {"reader", "writer"},
|
||||
"roles": {"admin", "user"},
|
||||
}, adminToken)
|
||||
if rr.Code != http.StatusNoContent {
|
||||
t.Errorf("set roles status = %d, want 204; body: %s", rr.Code, rr.Body.String())
|
||||
|
||||
@@ -70,11 +70,16 @@ func IssueToken(key ed25519.PrivateKey, issuer, subject string, roles []string,
|
||||
exp := now.Add(expiry)
|
||||
jti := uuid.New().String()
|
||||
|
||||
// Security (DEF-04): set NotBefore = now so tokens are not valid before
|
||||
// the instant of issuance. This is a defence-in-depth measure: without
|
||||
// nbf, a clock-skewed client or intermediate could present a token
|
||||
// before its intended validity window.
|
||||
jc := jwtClaims{
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: issuer,
|
||||
Subject: subject,
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(exp),
|
||||
ID: jti,
|
||||
},
|
||||
@@ -127,6 +132,9 @@ func ValidateToken(key ed25519.PublicKey, tokenString, expectedIssuer string) (*
|
||||
jwt.WithIssuedAt(),
|
||||
jwt.WithIssuer(expectedIssuer),
|
||||
jwt.WithExpirationRequired(),
|
||||
// Security (DEF-04): nbf is validated automatically by the library
|
||||
// when the claim is present; no explicit option is needed. If nbf is
|
||||
// in the future the library returns ErrTokenNotValidYet.
|
||||
)
|
||||
if err != nil {
|
||||
// Map library errors to our typed errors for consistent handling.
|
||||
|
||||
@@ -149,7 +149,7 @@ func (u *UIServer) handleTOTPStep(w http.ResponseWriter, r *http.Request) {
|
||||
u.render(w, "login", LoginData{Error: "internal error"})
|
||||
return
|
||||
}
|
||||
valid, err := auth.ValidateTOTP(secret, totpCode)
|
||||
valid, totpCounter, err := auth.ValidateTOTP(secret, totpCode)
|
||||
if err != nil || !valid {
|
||||
u.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"wrong_totp"}`)
|
||||
_ = u.db.RecordLoginFailure(acct.ID)
|
||||
@@ -166,6 +166,23 @@ func (u *UIServer) handleTOTPStep(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
}
|
||||
// Security (CRIT-01): reject replay of a code already used within its
|
||||
// ±30-second validity window.
|
||||
if err := u.db.CheckAndUpdateTOTPCounter(acct.ID, totpCounter); err != nil {
|
||||
u.writeAudit(r, model.EventLoginTOTPFail, &acct.ID, nil, `{"reason":"totp_replay"}`)
|
||||
_ = u.db.RecordLoginFailure(acct.ID)
|
||||
newNonce, nonceErr := u.issueTOTPNonce(acct.ID)
|
||||
if nonceErr != nil {
|
||||
u.render(w, "login", LoginData{Error: "internal error"})
|
||||
return
|
||||
}
|
||||
u.render(w, "totp_step", LoginData{
|
||||
Error: "invalid TOTP code",
|
||||
Username: username,
|
||||
Nonce: newNonce,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
u.finishLogin(w, r, acct)
|
||||
}
|
||||
@@ -251,7 +268,7 @@ func (u *UIServer) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// writeAudit is a fire-and-forget audit log helper for the UI package.
|
||||
func (u *UIServer) writeAudit(r *http.Request, eventType string, actorID, targetID *int64, details string) {
|
||||
ip := clientIP(r)
|
||||
ip := u.clientIP(r)
|
||||
if err := u.db.WriteAuditEvent(eventType, actorID, targetID, ip, details); err != nil {
|
||||
u.logger.Warn("write audit event", "type", eventType, "error", err)
|
||||
}
|
||||
|
||||
@@ -22,14 +22,15 @@ import (
|
||||
"html/template"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/config"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/middleware"
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"git.wntrmute.dev/kyle/mcias/web"
|
||||
)
|
||||
@@ -223,7 +224,7 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
|
||||
tmpls[name] = clone
|
||||
}
|
||||
|
||||
return &UIServer{
|
||||
srv := &UIServer{
|
||||
db: database,
|
||||
cfg: cfg,
|
||||
pubKey: pub,
|
||||
@@ -232,7 +233,33 @@ func New(database *db.DB, cfg *config.Config, priv ed25519.PrivateKey, pub ed255
|
||||
logger: logger,
|
||||
csrf: csrf,
|
||||
tmpls: tmpls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Security (DEF-02): launch a background goroutine to evict expired TOTP
|
||||
// nonces from pendingLogins. consumeTOTPNonce deletes entries on use, but
|
||||
// entries abandoned by users who never complete step 2 would otherwise
|
||||
// accumulate indefinitely, enabling a memory-exhaustion attack.
|
||||
go srv.cleanupPendingLogins()
|
||||
|
||||
return srv, nil
|
||||
}
|
||||
|
||||
// cleanupPendingLogins periodically evicts expired entries from pendingLogins.
|
||||
// It runs every 5 minutes, which is well within the 90-second nonce TTL, so
|
||||
// stale entries are removed before they can accumulate to any significant size.
|
||||
func (u *UIServer) cleanupPendingLogins() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
u.pendingLogins.Range(func(key, value any) bool {
|
||||
pl, ok := value.(*pendingLogin)
|
||||
if !ok || now.After(pl.expiresAt) {
|
||||
u.pendingLogins.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Register attaches all UI routes to mux, wrapped with security headers.
|
||||
@@ -259,9 +286,18 @@ func (u *UIServer) Register(mux *http.ServeMux) {
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
// Security (DEF-01, DEF-03): apply the same per-IP rate limit as the REST
|
||||
// /v1/auth/login endpoint, using the same proxy-aware IP extraction so
|
||||
// the rate limit is applied to real client IPs behind a reverse proxy.
|
||||
var trustedProxy net.IP
|
||||
if u.cfg.Server.TrustedProxy != "" {
|
||||
trustedProxy = net.ParseIP(u.cfg.Server.TrustedProxy)
|
||||
}
|
||||
loginRateLimit := middleware.RateLimit(10, 10, trustedProxy)
|
||||
|
||||
// Auth routes (no session required).
|
||||
uiMux.HandleFunc("GET /login", u.handleLoginPage)
|
||||
uiMux.HandleFunc("POST /login", u.handleLoginPost)
|
||||
uiMux.Handle("POST /login", loginRateLimit(http.HandlerFunc(u.handleLoginPost)))
|
||||
uiMux.HandleFunc("POST /logout", u.handleLogout)
|
||||
|
||||
// Protected routes.
|
||||
@@ -498,13 +534,15 @@ func securityHeaders(next http.Handler) http.Handler {
|
||||
})
|
||||
}
|
||||
|
||||
// clientIP extracts the client IP from RemoteAddr (best effort).
|
||||
func clientIP(r *http.Request) string {
|
||||
addr := r.RemoteAddr
|
||||
if idx := strings.LastIndex(addr, ":"); idx != -1 {
|
||||
return addr[:idx]
|
||||
// clientIP returns the real client IP for the request, respecting the
|
||||
// server's trusted-proxy setting (DEF-03). Delegates to middleware.ClientIP
|
||||
// so the same extraction logic is used for rate limiting and audit logging.
|
||||
func (u *UIServer) clientIP(r *http.Request) string {
|
||||
var proxyIP net.IP
|
||||
if u.cfg.Server.TrustedProxy != "" {
|
||||
proxyIP = net.ParseIP(u.cfg.Server.TrustedProxy)
|
||||
}
|
||||
return addr
|
||||
return middleware.ClientIP(r, proxyIP)
|
||||
}
|
||||
|
||||
// actorName resolves the username of the currently authenticated user from the
|
||||
|
||||
Reference in New Issue
Block a user