package server

import (
	"bytes"
	"crypto/ed25519"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha1"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"math"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/auth"
	"git.wntrmute.dev/kyle/mcias/internal/config"
	"git.wntrmute.dev/kyle/mcias/internal/db"
	"git.wntrmute.dev/kyle/mcias/internal/model"
	"git.wntrmute.dev/kyle/mcias/internal/token"
	"git.wntrmute.dev/kyle/mcias/internal/vault"
)

// generateTOTPCode computes a valid RFC 6238 TOTP code for the current time
// using the given raw secret bytes. Used in tests to confirm TOTP enrollment.
//
// The code is HOTP (RFC 4226) over HMAC-SHA1 of the 30-second time-step
// counter, dynamically truncated and rendered as a zero-padded 6-digit string.
// NOTE(review): a code generated right before a 30s step boundary may be
// checked by the server in the next step; this relies on the server's
// validation window (not visible here) to avoid rare flakes.
func generateTOTPCode(t *testing.T, secret []byte) string {
	t.Helper()

	const (
		stepSeconds = 30
		digits      = 6
	)

	counter := uint64(time.Now().Unix() / stepSeconds) //nolint:gosec // G115: always non-negative
	var counterBytes [8]byte
	binary.BigEndian.PutUint64(counterBytes[:], counter)

	mac := hmac.New(sha1.New, secret)
	if _, err := mac.Write(counterBytes[:]); err != nil {
		t.Fatalf("generateTOTPCode: HMAC write: %v", err)
	}
	sum := mac.Sum(nil)

	// Dynamic truncation (RFC 4226 §5.3): the low nibble of the last byte
	// selects a 4-byte window; the top bit is masked so the value is a
	// non-negative 31-bit integer.
	offset := sum[len(sum)-1] & 0x0F
	binCode := binary.BigEndian.Uint32(sum[offset:offset+4]) & 0x7FFFFFFF
	code := binCode % uint32(math.Pow10(digits))
	return fmt.Sprintf("%0*d", digits, code)
}

const testIssuer = "https://auth.example.com"

// newTestServer builds a Server backed by a fresh, migrated in-memory
// database, a random master key, and a newly generated Ed25519 signing key
// pair. The database is closed when the test finishes. Logs are discarded.
//
// It returns the server, the public and private signing keys (so tests can
// mint their own tokens), and the database handle.
func newTestServer(t *testing.T) (*Server, ed25519.PublicKey, ed25519.PrivateKey, *db.DB) {
	t.Helper()

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate db: %v", err)
	}
	t.Cleanup(func() { _ = database.Close() })

	masterKey := make([]byte, 32)
	if _, err := rand.Read(masterKey); err != nil {
		t.Fatalf("generate master key: %v", err)
	}

	cfg := config.NewTestConfig(testIssuer)
	v := vault.NewUnsealed(masterKey, priv, pub)
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(database, cfg, v, logger)
	return srv, pub, priv, database
}

// createTestHumanAccount creates a human account with password "testpass123".
func createTestHumanAccount(t *testing.T, srv *Server, username string) *model.Account {
	t.Helper()
	hash, err := auth.HashPassword("testpass123", auth.ArgonParams{Time: 3, Memory: 65536, Threads: 4})
	if err != nil {
		t.Fatalf("hash password: %v", err)
	}
	acct, err := srv.db.CreateAccount(username, model.AccountTypeHuman, hash)
	if err != nil {
		t.Fatalf("create account: %v", err)
	}
	return acct
}

// issueAdminToken creates an account with admin role, issues a JWT, and tracks it.
// It returns the signed token string and the created account.
func issueAdminToken(t *testing.T, srv *Server, priv ed25519.PrivateKey, username string) (string, *model.Account) {
	t.Helper()
	acct := createTestHumanAccount(t, srv, username)
	if err := srv.db.GrantRole(acct.ID, "admin", nil); err != nil {
		t.Fatalf("grant admin role: %v", err)
	}
	tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"admin"}, time.Hour)
	if err != nil {
		t.Fatalf("issue token: %v", err)
	}
	if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil {
		t.Fatalf("track token: %v", err)
	}
	return tokenStr, acct
}

// doRequest sends a request to handler and returns the recorded response.
// A non-nil body is JSON-encoded; a nil body sends an empty payload. The
// Content-Type is always application/json, and a non-empty authToken is sent
// as "Authorization: Bearer <authToken>".
func doRequest(t *testing.T, handler http.Handler, method, path string, body any, authToken string) *httptest.ResponseRecorder {
	t.Helper()

	var payload []byte // nil => empty request body
	if body != nil {
		b, err := json.Marshal(body)
		if err != nil {
			t.Fatalf("marshal body: %v", err)
		}
		payload = b
	}

	req := httptest.NewRequest(method, path, bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")
	if authToken != "" {
		req.Header.Set("Authorization", "Bearer "+authToken)
	}

	rr := httptest.NewRecorder()
	handler.ServeHTTP(rr, req)
	return rr
}

func TestHealth(t *testing.T) {
	srv, _, _, _ := newTestServer(t)
	handler := srv.Handler()

	rr := doRequest(t, handler, "GET", "/v1/health", nil, "")
	if rr.Code != http.StatusOK {
		t.Errorf("health status = %d, want 200", rr.Code)
	}
}
func TestPublicKey(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() rr := doRequest(t, handler, "GET", "/v1/keys/public", nil, "") if rr.Code != http.StatusOK { t.Errorf("public key status = %d, want 200", rr.Code) } var jwk map[string]string if err := json.Unmarshal(rr.Body.Bytes(), &jwk); err != nil { t.Fatalf("unmarshal JWK: %v", err) } if jwk["kty"] != "OKP" { t.Errorf("kty = %q, want OKP", jwk["kty"]) } if jwk["alg"] != "EdDSA" { t.Errorf("alg = %q, want EdDSA", jwk["alg"]) } } func TestLoginSuccess(t *testing.T) { srv, _, _, _ := newTestServer(t) createTestHumanAccount(t, srv, "alice") handler := srv.Handler() rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "alice", "password": "testpass123", }, "") if rr.Code != http.StatusOK { t.Errorf("login status = %d, want 200; body: %s", rr.Code, rr.Body.String()) } var resp loginResponse if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal login response: %v", err) } if resp.Token == "" { t.Error("expected non-empty token in login response") } if resp.ExpiresAt == "" { t.Error("expected non-empty expires_at in login response") } } func TestLoginWrongPassword(t *testing.T) { srv, _, _, _ := newTestServer(t) createTestHumanAccount(t, srv, "bob") handler := srv.Handler() rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "bob", "password": "wrongpassword", }, "") if rr.Code != http.StatusUnauthorized { t.Errorf("status = %d, want 401", rr.Code) } } func TestLoginUnknownUser(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "nobody", "password": "password", }, "") if rr.Code != http.StatusUnauthorized { t.Errorf("status = %d, want 401", rr.Code) } } func TestLoginResponseDoesNotContainCredentials(t *testing.T) { srv, _, _, _ := newTestServer(t) createTestHumanAccount(t, srv, "charlie") 
handler := srv.Handler() rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "charlie", "password": "testpass123", }, "") body := rr.Body.String() // Security: password hash must never appear in any API response. if strings.Contains(body, "argon2id") || strings.Contains(body, "password_hash") { t.Error("login response contains password hash material") } } func TestTokenValidate(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "dave") handler := srv.Handler() // Issue and track a token. tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } req := httptest.NewRequest("POST", "/v1/token/validate", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Fatalf("validate status = %d, want 200", rr.Code) } var resp validateResponse if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal: %v", err) } if !resp.Valid { t.Error("expected valid=true for valid token") } } func TestLogout(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "eve") handler := srv.Handler() tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } // Logout. rr := doRequest(t, handler, "POST", "/v1/auth/logout", nil, tokenStr) if rr.Code != http.StatusNoContent { t.Errorf("logout status = %d, want 204; body: %s", rr.Code, rr.Body.String()) } // Token should now be invalid on validate. 
req := httptest.NewRequest("POST", "/v1/token/validate", nil) req.Header.Set("Authorization", "Bearer "+tokenStr) rr2 := httptest.NewRecorder() handler.ServeHTTP(rr2, req) var resp validateResponse _ = json.Unmarshal(rr2.Body.Bytes(), &resp) if resp.Valid { t.Error("expected valid=false after logout") } } func TestCreateAccountAdmin(t *testing.T) { srv, _, priv, _ := newTestServer(t) adminToken, _ := issueAdminToken(t, srv, priv, "admin-user") handler := srv.Handler() rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{ "username": "new-user", "password": "newpassword123", "account_type": "human", }, adminToken) if rr.Code != http.StatusCreated { t.Errorf("create account status = %d, want 201; body: %s", rr.Code, rr.Body.String()) } var resp accountResponse if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal: %v", err) } if resp.Username != "new-user" { t.Errorf("Username = %q, want %q", resp.Username, "new-user") } // Security: password hash must not appear in account response. 
body := rr.Body.String() if strings.Contains(body, "password_hash") || strings.Contains(body, "argon2id") { t.Error("account creation response contains password hash") } } func TestCreateAccountRequiresAdmin(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "regular-user") tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, []string{"reader"}, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } handler := srv.Handler() rr := doRequest(t, handler, "POST", "/v1/accounts", map[string]string{ "username": "other-user", "password": "password", "account_type": "human", }, tokenStr) if rr.Code != http.StatusForbidden { t.Errorf("status = %d, want 403", rr.Code) } } func TestListAccounts(t *testing.T) { srv, _, priv, _ := newTestServer(t) adminToken, _ := issueAdminToken(t, srv, priv, "admin2") createTestHumanAccount(t, srv, "user1") createTestHumanAccount(t, srv, "user2") handler := srv.Handler() rr := doRequest(t, handler, "GET", "/v1/accounts", nil, adminToken) if rr.Code != http.StatusOK { t.Errorf("list accounts status = %d, want 200", rr.Code) } var accounts []accountResponse if err := json.Unmarshal(rr.Body.Bytes(), &accounts); err != nil { t.Fatalf("unmarshal: %v", err) } if len(accounts) < 3 { // admin + user1 + user2 t.Errorf("expected at least 3 accounts, got %d", len(accounts)) } // Security: no credential fields in any response. 
body := rr.Body.String() for _, bad := range []string{"password_hash", "argon2id", "totp_secret", "PasswordHash"} { if strings.Contains(body, bad) { t.Errorf("account list response contains credential field %q", bad) } } } func TestDeleteAccount(t *testing.T) { srv, _, priv, _ := newTestServer(t) adminToken, _ := issueAdminToken(t, srv, priv, "admin3") target := createTestHumanAccount(t, srv, "delete-me") handler := srv.Handler() rr := doRequest(t, handler, "DELETE", "/v1/accounts/"+target.UUID, nil, adminToken) if rr.Code != http.StatusNoContent { t.Errorf("delete status = %d, want 204; body: %s", rr.Code, rr.Body.String()) } } func TestSetAndGetRoles(t *testing.T) { srv, _, priv, _ := newTestServer(t) adminToken, _ := issueAdminToken(t, srv, priv, "admin4") target := createTestHumanAccount(t, srv, "role-target") handler := srv.Handler() // Set roles. rr := doRequest(t, handler, "PUT", "/v1/accounts/"+target.UUID+"/roles", map[string][]string{ "roles": {"admin", "user"}, }, adminToken) if rr.Code != http.StatusNoContent { t.Errorf("set roles status = %d, want 204; body: %s", rr.Code, rr.Body.String()) } // Get roles. rr2 := doRequest(t, handler, "GET", "/v1/accounts/"+target.UUID+"/roles", nil, adminToken) if rr2.Code != http.StatusOK { t.Errorf("get roles status = %d, want 200", rr2.Code) } var resp rolesResponse if err := json.Unmarshal(rr2.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal: %v", err) } if len(resp.Roles) != 2 { t.Errorf("expected 2 roles, got %d", len(resp.Roles)) } } func TestLoginRateLimited(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() // The login endpoint uses RateLimit(10, 10): burst of 10 requests. // We send all burst+1 requests concurrently so they all hit the rate // limiter before any Argon2id hash can complete. This is necessary because // Argon2id takes ~500ms per request; sequential requests refill the // token bucket faster than they drain it at 10 req/s. 
const burst = 10 bodyJSON := []byte(`{"username":"nobody","password":"wrong"}`) type result struct { hdr http.Header code int } results := make([]result, burst+1) var wg sync.WaitGroup for i := range burst + 1 { wg.Add(1) go func(idx int) { defer wg.Done() req := httptest.NewRequest("POST", "/v1/auth/login", bytes.NewReader(bodyJSON)) req.Header.Set("Content-Type", "application/json") req.RemoteAddr = "10.1.1.1:9999" // same IP for all rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) results[idx] = result{code: rr.Code, hdr: rr.Result().Header} }(i) } wg.Wait() // At least one of the burst+1 concurrent requests must have been // rate-limited (429). Which one is non-deterministic. var got429 bool var retryAfterSet bool for _, r := range results { if r.code == http.StatusTooManyRequests { got429 = true retryAfterSet = r.hdr.Get("Retry-After") != "" break } } if !got429 { t.Error("expected at least one 429 after burst+1 concurrent login requests") } if !retryAfterSet { t.Error("expected Retry-After header on 429 response") } } func TestTokenValidateRateLimited(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() // The token/validate endpoint shares the same per-IP rate limiter as login. // Use a distinct RemoteAddr so we get a fresh bucket. body := map[string]string{"token": "not.a.valid.token"} for i := range 10 { b, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/v1/token/validate", bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") req.RemoteAddr = "10.99.99.1:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code == http.StatusTooManyRequests { t.Fatalf("request %d was rate-limited prematurely", i+1) } } // 11th request should be rate-limited. 
b, _ := json.Marshal(body) req := httptest.NewRequest("POST", "/v1/token/validate", bytes.NewReader(b)) req.Header.Set("Content-Type", "application/json") req.RemoteAddr = "10.99.99.1:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusTooManyRequests { t.Errorf("expected 429 after exhausting burst, got %d", rr.Code) } } func TestHealthNotRateLimited(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() // Health endpoint should not be rate-limited — send 20 rapid requests. for i := range 20 { req := httptest.NewRequest("GET", "/v1/health", nil) req.RemoteAddr = "10.88.88.1:12345" rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusOK { t.Errorf("health request %d: status = %d, want 200", i+1, rr.Code) } } } // TestTOTPEnrollDoesNotRequireTOTP verifies that initiating TOTP enrollment // (POST /v1/auth/totp/enroll) stores the pending secret without setting // totp_required=1. A user who starts but does not complete enrollment must // still be able to log in with password alone — no lockout. // // Security: this guards against F-01 (enrollment sets the flag prematurely), // which would let an attacker initiate enrollment for a victim account and // then prevent that account from authenticating. func TestTOTPEnrollDoesNotRequireTOTP(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "totp-enroll-user") handler := srv.Handler() // Issue a token for this user. tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } // Start enrollment (password required since SEC-01 fix). 
rr := doRequest(t, handler, "POST", "/v1/auth/totp/enroll", totpEnrollRequest{ Password: "testpass123", }, tokenStr) if rr.Code != http.StatusOK { t.Fatalf("enroll status = %d, want 200; body: %s", rr.Code, rr.Body.String()) } var enrollResp totpEnrollResponse if err := json.Unmarshal(rr.Body.Bytes(), &enrollResp); err != nil { t.Fatalf("unmarshal enroll response: %v", err) } if enrollResp.Secret == "" { t.Error("expected non-empty TOTP secret in enrollment response") } // Security: totp_required must still be false after enrollment start. // If it were true the user would be locked out until they confirm. freshAcct, err := srv.db.GetAccountByUUID(acct.UUID) if err != nil { t.Fatalf("GetAccountByUUID: %v", err) } if freshAcct.TOTPRequired { t.Error("totp_required = true after enroll — lockout risk (F-01)") } // The pending secret should be stored (needed for confirm). if freshAcct.TOTPSecretEnc == nil { t.Error("totp_secret_enc is nil after enroll — confirm would fail") } // Login without TOTP code must still succeed (enrollment not confirmed). rr2 := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "totp-enroll-user", "password": "testpass123", }, "") if rr2.Code != http.StatusOK { t.Errorf("login without TOTP after incomplete enrollment: status = %d, want 200; body: %s", rr2.Code, rr2.Body.String()) } } // TestTOTPEnrollRequiresPassword verifies that TOTP enrollment (SEC-01) // requires the current password. A stolen session token alone must not be // sufficient to add attacker-controlled MFA to the victim's account. 
func TestTOTPEnrollRequiresPassword(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "totp-pw-check") handler := srv.Handler() tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } t.Run("no password", func(t *testing.T) { rr := doRequest(t, handler, "POST", "/v1/auth/totp/enroll", totpEnrollRequest{}, tokenStr) if rr.Code != http.StatusBadRequest { t.Errorf("enroll without password: status = %d, want %d; body: %s", rr.Code, http.StatusBadRequest, rr.Body.String()) } }) t.Run("wrong password", func(t *testing.T) { rr := doRequest(t, handler, "POST", "/v1/auth/totp/enroll", totpEnrollRequest{ Password: "wrong-password", }, tokenStr) if rr.Code != http.StatusUnauthorized { t.Errorf("enroll with wrong password: status = %d, want %d; body: %s", rr.Code, http.StatusUnauthorized, rr.Body.String()) } }) t.Run("correct password", func(t *testing.T) { rr := doRequest(t, handler, "POST", "/v1/auth/totp/enroll", totpEnrollRequest{ Password: "testpass123", }, tokenStr) if rr.Code != http.StatusOK { t.Fatalf("enroll with correct password: status = %d, want 200; body: %s", rr.Code, rr.Body.String()) } var resp totpEnrollResponse if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal: %v", err) } if resp.Secret == "" { t.Error("expected non-empty TOTP secret") } if resp.OTPAuthURI == "" { t.Error("expected non-empty otpauth URI") } }) } func TestRenewToken(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "renew-user") handler := srv.Handler() // Issue a short-lived token (4s) so we can wait past the 50% threshold // while leaving enough headroom before expiry to avoid flakiness. 
oldTokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, 4*time.Second) if err != nil { t.Fatalf("IssueToken: %v", err) } oldJTI := claims.JTI if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } // Wait for >50% of the 4s lifetime to elapse. time.Sleep(2100 * time.Millisecond) rr := doRequest(t, handler, "POST", "/v1/auth/renew", nil, oldTokenStr) if rr.Code != http.StatusOK { t.Fatalf("renew status = %d, want 200; body: %s", rr.Code, rr.Body.String()) } var resp loginResponse if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { t.Fatalf("unmarshal renew response: %v", err) } if resp.Token == "" || resp.Token == oldTokenStr { t.Error("expected new, distinct token after renewal") } // Old token should be revoked in the database. rec, err := srv.db.GetTokenRecord(oldJTI) if err != nil { t.Fatalf("GetTokenRecord: %v", err) } if !rec.IsRevoked() { t.Error("old token should be revoked after renewal") } } func TestOversizedJSONBodyRejected(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() // Build a JSON body larger than 1 MiB. oversized := bytes.Repeat([]byte("A"), (1<<20)+1) body := []byte(`{"username":"admin","password":"` + string(oversized) + `"}`) req := httptest.NewRequest("POST", "/v1/auth/login", bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if rr.Code != http.StatusBadRequest { t.Errorf("expected 400 for oversized body, got %d", rr.Code) } } // TestSecurityHeadersOnAPIResponses verifies that the global security-headers // middleware (SEC-04) sets X-Content-Type-Options, Strict-Transport-Security, // and Cache-Control on all API responses, not just the UI. 
func TestSecurityHeadersOnAPIResponses(t *testing.T) { srv, _, _, _ := newTestServer(t) handler := srv.Handler() wantHeaders := map[string]string{ "X-Content-Type-Options": "nosniff", "Strict-Transport-Security": "max-age=63072000; includeSubDomains", "Cache-Control": "no-store", } t.Run("GET /v1/health", func(t *testing.T) { rr := doRequest(t, handler, "GET", "/v1/health", nil, "") if rr.Code != http.StatusOK { t.Fatalf("status = %d, want 200", rr.Code) } for header, want := range wantHeaders { got := rr.Header().Get(header) if got != want { t.Errorf("%s = %q, want %q", header, got, want) } } }) t.Run("POST /v1/auth/login", func(t *testing.T) { createTestHumanAccount(t, srv, "sec04-user") rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "sec04-user", "password": "testpass123", }, "") if rr.Code != http.StatusOK { t.Fatalf("status = %d, want 200; body: %s", rr.Code, rr.Body.String()) } for header, want := range wantHeaders { got := rr.Header().Get(header) if got != want { t.Errorf("%s = %q, want %q", header, got, want) } } }) } // TestLoginLockedAccountReturns401 verifies that a locked-out account gets the // same HTTP 401 / "invalid credentials" response as a wrong-password attempt, // preventing user-enumeration via lockout differentiation (SEC-02). func TestLoginLockedAccountReturns401(t *testing.T) { srv, _, _, database := newTestServer(t) acct := createTestHumanAccount(t, srv, "lockuser") handler := srv.Handler() // Lower the lockout threshold so we don't need 10 failures. origThreshold := db.LockoutThreshold db.LockoutThreshold = 3 t.Cleanup(func() { db.LockoutThreshold = origThreshold }) // Record enough failures to trigger lockout. for range db.LockoutThreshold { if err := database.RecordLoginFailure(acct.ID); err != nil { t.Fatalf("RecordLoginFailure: %v", err) } } // Confirm the account is locked. 
locked, err := database.IsLockedOut(acct.ID) if err != nil { t.Fatalf("IsLockedOut: %v", err) } if !locked { t.Fatal("expected account to be locked out after threshold failures") } // Attempt login on the locked account. lockedRR := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "lockuser", "password": "testpass123", }, "") // Also attempt login with a wrong password (not locked) for comparison. wrongRR := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "lockuser", "password": "wrongpassword", }, "") // Both must return 401, not 429. if lockedRR.Code != http.StatusUnauthorized { t.Errorf("locked account: status = %d, want %d", lockedRR.Code, http.StatusUnauthorized) } if wrongRR.Code != http.StatusUnauthorized { t.Errorf("wrong password: status = %d, want %d", wrongRR.Code, http.StatusUnauthorized) } // Parse the JSON bodies and compare — they must be identical. type errResp struct { Error string `json:"error"` Code string `json:"code"` } var lockedBody, wrongBody errResp if err := json.Unmarshal(lockedRR.Body.Bytes(), &lockedBody); err != nil { t.Fatalf("unmarshal locked body: %v", err) } if err := json.Unmarshal(wrongRR.Body.Bytes(), &wrongBody); err != nil { t.Fatalf("unmarshal wrong body: %v", err) } if lockedBody != wrongBody { t.Errorf("locked response %+v differs from wrong-password response %+v", lockedBody, wrongBody) } if lockedBody.Code != "unauthorized" { t.Errorf("locked response code = %q, want %q", lockedBody.Code, "unauthorized") } if lockedBody.Error != "invalid credentials" { t.Errorf("locked response error = %q, want %q", lockedBody.Error, "invalid credentials") } } // TestRenewTokenTooEarly verifies that a token cannot be renewed before 50% // of its lifetime has elapsed (SEC-03). // TestExtractBearerFromRequest verifies that extractBearerFromRequest correctly // validates the "Bearer" prefix before extracting the token string. 
// Security (PEN-01): the previous implementation sliced at a fixed offset // without checking the prefix, accepting any 8+ character Authorization value. func TestExtractBearerFromRequest(t *testing.T) { tests := []struct { name string header string want string wantErr bool }{ {"valid", "Bearer mytoken123", "mytoken123", false}, {"missing header", "", "", true}, {"no bearer prefix", "Token mytoken123", "", true}, {"basic auth scheme", "Basic dXNlcjpwYXNz", "", true}, {"empty token", "Bearer ", "", true}, {"bearer only no space", "Bearer", "", true}, {"case insensitive", "bearer mytoken123", "mytoken123", false}, {"mixed case", "BEARER mytoken123", "mytoken123", false}, {"garbage 8 chars", "XXXXXXXX", "", true}, {"token with spaces", "Bearer token with spaces", "token with spaces", false}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) if tc.header != "" { req.Header.Set("Authorization", tc.header) } got, err := extractBearerFromRequest(req) if (err != nil) != tc.wantErr { t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err) } if !tc.wantErr && got != tc.want { t.Errorf("token = %q, want %q", got, tc.want) } }) } } func TestRenewTokenTooEarly(t *testing.T) { srv, _, priv, _ := newTestServer(t) acct := createTestHumanAccount(t, srv, "renew-early-user") handler := srv.Handler() // Issue a long-lived token so 50% is far in the future. tokStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } // Immediately try to renew — should be rejected. 
rr := doRequest(t, handler, "POST", "/v1/auth/renew", nil, tokStr) if rr.Code != http.StatusBadRequest { t.Fatalf("renew status = %d, want 400; body: %s", rr.Code, rr.Body.String()) } if !strings.Contains(rr.Body.String(), "not yet eligible for renewal") { t.Errorf("expected eligibility message, got: %s", rr.Body.String()) } } // TestTOTPMissingDoesNotIncrementLockout verifies that a login attempt with // a correct password but missing TOTP code does NOT increment the account // lockout counter (PEN-06 / DEF-08). // // Security: incrementing the lockout counter for a missing TOTP code would // allow an attacker to lock out a TOTP-enrolled account by repeatedly sending // the correct password with no TOTP code — without needing to guess TOTP. // It would also penalise well-behaved two-step clients. func TestTOTPMissingDoesNotIncrementLockout(t *testing.T) { srv, _, priv, database := newTestServer(t) acct := createTestHumanAccount(t, srv, "totp-lockout-user") handler := srv.Handler() // Issue a token so we can call the TOTP enroll and confirm endpoints. tokenStr, claims, err := token.IssueToken(priv, testIssuer, acct.UUID, nil, time.Hour) if err != nil { t.Fatalf("IssueToken: %v", err) } if err := srv.db.TrackToken(claims.JTI, acct.ID, claims.IssuedAt, claims.ExpiresAt); err != nil { t.Fatalf("TrackToken: %v", err) } // Enroll TOTP — get back the base32 secret. enrollRR := doRequest(t, handler, "POST", "/v1/auth/totp/enroll", totpEnrollRequest{ Password: "testpass123", }, tokenStr) if enrollRR.Code != http.StatusOK { t.Fatalf("enroll status = %d, want 200; body: %s", enrollRR.Code, enrollRR.Body.String()) } var enrollResp totpEnrollResponse if err := json.Unmarshal(enrollRR.Body.Bytes(), &enrollResp); err != nil { t.Fatalf("unmarshal enroll: %v", err) } // Decode the secret and generate a valid TOTP code to confirm enrollment. // We compute the TOTP code inline using the same RFC 6238 algorithm used // by auth.ValidateTOTP, since auth.hotp is not exported. 
secretBytes, err := auth.DecodeTOTPSecret(enrollResp.Secret) if err != nil { t.Fatalf("DecodeTOTPSecret: %v", err) } currentCode := generateTOTPCode(t, secretBytes) // Confirm enrollment. confirmRR := doRequest(t, handler, "POST", "/v1/auth/totp/confirm", map[string]string{ "code": currentCode, }, tokenStr) if confirmRR.Code != http.StatusNoContent { t.Fatalf("confirm status = %d, want 204; body: %s", confirmRR.Code, confirmRR.Body.String()) } // Account should now require TOTP. Lower the lockout threshold to 1 so // that a single RecordLoginFailure call would immediately lock the account. origThreshold := db.LockoutThreshold db.LockoutThreshold = 1 t.Cleanup(func() { db.LockoutThreshold = origThreshold }) // Attempt login with the correct password but no TOTP code. rr := doRequest(t, handler, "POST", "/v1/auth/login", map[string]string{ "username": "totp-lockout-user", "password": "testpass123", }, "") if rr.Code != http.StatusUnauthorized { t.Fatalf("expected 401 for missing TOTP, got %d; body: %s", rr.Code, rr.Body.String()) } // The error code must be totp_required, not unauthorized. var errResp struct { Code string `json:"code"` } if err := json.Unmarshal(rr.Body.Bytes(), &errResp); err != nil { t.Fatalf("unmarshal error response: %v", err) } if errResp.Code != "totp_required" { t.Errorf("error code = %q, want %q", errResp.Code, "totp_required") } // Security (PEN-06): the lockout counter must NOT have been incremented. // With threshold=1, if it had been incremented the account would now be // locked and a subsequent login with correct credentials would fail. locked, err := database.IsLockedOut(acct.ID) if err != nil { t.Fatalf("IsLockedOut: %v", err) } if locked { t.Error("account was locked after TOTP-missing login — lockout counter was incorrectly incremented (PEN-06)") } }