package db import ( "database/sql" "errors" "fmt" "time" "git.wntrmute.dev/kyle/mcias/internal/model" "github.com/google/uuid" ) // CreateAccount inserts a new account record. The UUID is generated // automatically. Returns the created Account with its DB-assigned ID and UUID. func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) { id := uuid.New().String() n := now() result, err := db.sql.Exec(` INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at) VALUES (?, ?, ?, ?, 'active', ?, ?) `, id, username, string(accountType), nullString(passwordHash), n, n) if err != nil { return nil, fmt.Errorf("db: create account %q: %w", username, err) } rowID, err := result.LastInsertId() if err != nil { return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err) } createdAt, err := parseTime(n) if err != nil { return nil, err } return &model.Account{ ID: rowID, UUID: id, Username: username, AccountType: accountType, Status: model.AccountStatusActive, PasswordHash: passwordHash, CreatedAt: createdAt, UpdatedAt: createdAt, }, nil } // GetAccountByUUID retrieves an account by its external UUID. // Returns ErrNotFound if no matching account exists. func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) { return db.scanAccount(db.sql.QueryRow(` SELECT id, uuid, username, account_type, COALESCE(password_hash,''), status, totp_required, totp_secret_enc, totp_secret_nonce, created_at, updated_at, deleted_at FROM accounts WHERE uuid = ? `, accountUUID)) } // GetAccountByID retrieves an account by its numeric primary key. // Returns ErrNotFound if no matching account exists. 
func (db *DB) GetAccountByID(id int64) (*model.Account, error) {
	row := db.sql.QueryRow(`
		SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
		       status, totp_required, totp_secret_enc, totp_secret_nonce,
		       created_at, updated_at, deleted_at
		FROM accounts WHERE id = ?
	`, id)
	return db.scanAccount(row)
}

// GetAccountByUsername retrieves an account whose username equals the given
// value. It returns ErrNotFound if no matching account exists.
//
// NOTE(review): the query uses a plain `username = ?` comparison. In SQLite
// that is case-insensitive only if the column is declared COLLATE NOCASE (the
// default BINARY collation is case-sensitive). Confirm against the schema
// before relying on case-insensitive lookup.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
	row := db.sql.QueryRow(`
		SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
		       status, totp_required, totp_secret_enc, totp_secret_nonce,
		       created_at, updated_at, deleted_at
		FROM accounts WHERE username = ?
	`, username)
	return db.scanAccount(row)
}

// ListAccounts returns every account whose status is not 'deleted', sorted by
// username ascending. The slice is nil when no account matches.
func (db *DB) ListAccounts() ([]*model.Account, error) {
	rows, err := db.sql.Query(`
		SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
		       status, totp_required, totp_secret_enc, totp_secret_nonce,
		       created_at, updated_at, deleted_at
		FROM accounts WHERE status != 'deleted'
		ORDER BY username ASC
	`)
	if err != nil {
		return nil, fmt.Errorf("db: list accounts: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var out []*model.Account
	for rows.Next() {
		acct, scanErr := db.scanAccountRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, acct)
	}
	return out, rows.Err()
}

// UpdateAccountStatus sets an account's status and bumps updated_at.
//
// When the new status is AccountStatusDeleted, deleted_at is set to the same
// timestamp; for any other status deleted_at is cleared to NULL. No error is
// returned when accountID matches no row.
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
	stamp := now()

	var deletedAt *string
	if status == model.AccountStatusDeleted {
		deletedAt = &stamp
	}

	if _, err := db.sql.Exec(`
		UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
		WHERE id = ?
	`, string(status), deletedAt, stamp, accountID); err != nil {
		return fmt.Errorf("db: update account status: %w", err)
	}
	return nil
}

// UpdatePasswordHash updates the Argon2id password hash for an account.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error { _, err := db.sql.Exec(` UPDATE accounts SET password_hash = ?, updated_at = ? WHERE id = ? `, hash, now(), accountID) if err != nil { return fmt.Errorf("db: update password hash: %w", err) } return nil } // StorePendingTOTP stores the encrypted TOTP secret without enabling the // totp_required flag. This is called during enrollment (POST /v1/auth/totp/enroll) // so the secret is available for confirmation without yet enforcing TOTP on login. // The flag is set to 1 only after the user successfully confirms the code via // handleTOTPConfirm, which calls SetTOTP. // // Security: keeping totp_required=0 during enrollment prevents the user from // being locked out if they abandon the enrollment flow after the secret is // generated but before they have set up their authenticator app. func (db *DB) StorePendingTOTP(accountID int64, secretEnc, secretNonce []byte) error { _, err := db.sql.Exec(` UPDATE accounts SET totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ? WHERE id = ? `, secretEnc, secretNonce, now(), accountID) if err != nil { return fmt.Errorf("db: store pending TOTP: %w", err) } return nil } // SetTOTP stores the encrypted TOTP secret and marks TOTP as required. // Call this only after the user has confirmed the TOTP code; for the initial // enrollment step use StorePendingTOTP instead. func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error { _, err := db.sql.Exec(` UPDATE accounts SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ? WHERE id = ? `, secretEnc, secretNonce, now(), accountID) if err != nil { return fmt.Errorf("db: set TOTP: %w", err) } return nil } // ClearTOTP removes the TOTP secret and disables TOTP requirement. func (db *DB) ClearTOTP(accountID int64) error { _, err := db.sql.Exec(` UPDATE accounts SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ? WHERE id = ? 
`, now(), accountID) if err != nil { return fmt.Errorf("db: clear TOTP: %w", err) } return nil } // scanAccount scans a single account row from a *sql.Row. func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) { var a model.Account var accountType, status string var totpRequired int var createdAtStr, updatedAtStr string var deletedAtStr *string var totpSecretEnc, totpSecretNonce []byte err := row.Scan( &a.ID, &a.UUID, &a.Username, &accountType, &a.PasswordHash, &status, &totpRequired, &totpSecretEnc, &totpSecretNonce, &createdAtStr, &updatedAtStr, &deletedAtStr, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: scan account: %w", err) } return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr) } // scanAccountRow scans a single account from *sql.Rows. func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) { var a model.Account var accountType, status string var totpRequired int var createdAtStr, updatedAtStr string var deletedAtStr *string var totpSecretEnc, totpSecretNonce []byte err := rows.Scan( &a.ID, &a.UUID, &a.Username, &accountType, &a.PasswordHash, &status, &totpRequired, &totpSecretEnc, &totpSecretNonce, &createdAtStr, &updatedAtStr, &deletedAtStr, ) if err != nil { return nil, fmt.Errorf("db: scan account row: %w", err) } return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr) } func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) { a.AccountType = model.AccountType(accountType) a.Status = model.AccountStatus(status) a.TOTPRequired = totpRequired == 1 a.TOTPSecretEnc = totpSecretEnc a.TOTPSecretNonce = totpSecretNonce var err error a.CreatedAt, err = 
parseTime(createdAtStr) if err != nil { return nil, err } a.UpdatedAt, err = parseTime(updatedAtStr) if err != nil { return nil, err } a.DeletedAt, err = nullableTime(deletedAtStr) if err != nil { return nil, err } return a, nil } // nullString converts an empty string to nil for nullable SQL columns. func nullString(s string) *string { if s == "" { return nil } return &s } // GetRoles returns the role strings assigned to an account. func (db *DB) GetRoles(accountID int64) ([]string, error) { rows, err := db.sql.Query(` SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC `, accountID) if err != nil { return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err) } defer func() { _ = rows.Close() }() var roles []string for rows.Next() { var role string if err := rows.Scan(&role); err != nil { return nil, fmt.Errorf("db: scan role: %w", err) } roles = append(roles, role) } return roles, rows.Err() } // GrantRole adds a role to an account. If the role already exists, it is a no-op. func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error { _, err := db.sql.Exec(` INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at) VALUES (?, ?, ?, ?) `, accountID, role, grantedBy, now()) if err != nil { return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err) } return nil } // RevokeRole removes a role from an account. func (db *DB) RevokeRole(accountID int64, role string) error { _, err := db.sql.Exec(` DELETE FROM account_roles WHERE account_id = ? AND role = ? `, accountID, role) if err != nil { return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err) } return nil } // SetRoles replaces the full role set for an account atomically. 
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error { tx, err := db.sql.Begin() if err != nil { return fmt.Errorf("db: set roles begin tx: %w", err) } if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil { _ = tx.Rollback() return fmt.Errorf("db: set roles delete existing: %w", err) } n := now() for _, role := range roles { if _, err := tx.Exec(` INSERT INTO account_roles (account_id, role, granted_by, granted_at) VALUES (?, ?, ?, ?) `, accountID, role, grantedBy, n); err != nil { _ = tx.Rollback() return fmt.Errorf("db: set roles insert %q: %w", role, err) } } if err := tx.Commit(); err != nil { return fmt.Errorf("db: set roles commit: %w", err) } return nil } // HasRole reports whether an account holds the given role. func (db *DB) HasRole(accountID int64, role string) (bool, error) { var count int err := db.sql.QueryRow(` SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ? `, accountID, role).Scan(&count) if err != nil { return false, fmt.Errorf("db: has role: %w", err) } return count > 0, nil } // WriteServerConfig stores the encrypted Ed25519 signing key. // There can only be one row (id=1). func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error { n := now() _, err := db.sql.Exec(` INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at) VALUES (1, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET signing_key_enc = excluded.signing_key_enc, signing_key_nonce = excluded.signing_key_nonce, updated_at = excluded.updated_at `, signingKeyEnc, signingKeyNonce, n, n) if err != nil { return fmt.Errorf("db: write server config: %w", err) } return nil } // ReadServerConfig returns the encrypted signing key and nonce. // Returns ErrNotFound if no config row exists yet. 
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) { err = db.sql.QueryRow(` SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1 `).Scan(&signingKeyEnc, &signingKeyNonce) if errors.Is(err, sql.ErrNoRows) { return nil, nil, ErrNotFound } if err != nil { return nil, nil, fmt.Errorf("db: read server config: %w", err) } return signingKeyEnc, signingKeyNonce, nil } // WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation. // The salt must be stable across restarts so the same passphrase always yields // the same master key. There can only be one row (id=1). func (db *DB) WriteMasterKeySalt(salt []byte) error { n := now() _, err := db.sql.Exec(` INSERT INTO server_config (id, master_key_salt, created_at, updated_at) VALUES (1, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET master_key_salt = excluded.master_key_salt, updated_at = excluded.updated_at `, salt, n, n) if err != nil { return fmt.Errorf("db: write master key salt: %w", err) } return nil } // ReadMasterKeySalt returns the stored Argon2id KDF salt. // Returns ErrNotFound if no salt has been stored yet (first run). func (db *DB) ReadMasterKeySalt() ([]byte, error) { var salt []byte err := db.sql.QueryRow(` SELECT master_key_salt FROM server_config WHERE id = 1 `).Scan(&salt) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: read master key salt: %w", err) } if salt == nil { return nil, ErrNotFound } return salt, nil } // WritePGCredentials stores or replaces the Postgres credentials for an account. func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error { n := now() _, err := db.sql.Exec(` INSERT INTO pg_credentials (account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) 
ON CONFLICT(account_id) DO UPDATE SET pg_host = excluded.pg_host, pg_port = excluded.pg_port, pg_database = excluded.pg_database, pg_username = excluded.pg_username, pg_password_enc = excluded.pg_password_enc, pg_password_nonce = excluded.pg_password_nonce, updated_at = excluded.updated_at `, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n) if err != nil { return fmt.Errorf("db: write pg credentials: %w", err) } return nil } // ReadPGCredentials retrieves the encrypted Postgres credentials for an account. // Returns ErrNotFound if no credentials are stored. func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) { var cred model.PGCredential var createdAtStr, updatedAtStr string err := db.sql.QueryRow(` SELECT id, account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at FROM pg_credentials WHERE account_id = ? `, accountID).Scan( &cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort, &cred.PGDatabase, &cred.PGUsername, &cred.PGPasswordEnc, &cred.PGPasswordNonce, &createdAtStr, &updatedAtStr, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: read pg credentials: %w", err) } cred.CreatedAt, err = parseTime(createdAtStr) if err != nil { return nil, err } cred.UpdatedAt, err = parseTime(updatedAtStr) if err != nil { return nil, err } return &cred, nil } // ---- Login lockout (F-08) ---- // LockoutWindow is the rolling window for failed-login counting. // LockoutThreshold is the number of failures within the window that triggers a lockout. // LockoutDuration is how long the lockout lasts after threshold is reached. // These are package-level vars (not consts) so tests can override them. // // Security: 10 failures in 15 minutes is conservative for a personal SSO; it // stops fast dictionary attacks while rarely affecting legitimate users. 
var ( LockoutWindow = 15 * time.Minute LockoutThreshold = 10 LockoutDuration = 15 * time.Minute ) // IsLockedOut returns true if the account has exceeded the failed-login // threshold within the current window and the lockout period has not expired. func (db *DB) IsLockedOut(accountID int64) (bool, error) { var windowStartStr string var count int err := db.sql.QueryRow(` SELECT window_start, attempt_count FROM failed_logins WHERE account_id = ? `, accountID).Scan(&windowStartStr, &count) if errors.Is(err, sql.ErrNoRows) { return false, nil } if err != nil { return false, fmt.Errorf("db: is locked out %d: %w", accountID, err) } windowStart, err := parseTime(windowStartStr) if err != nil { return false, err } // Window has expired — not locked out. if time.Since(windowStart) > LockoutWindow+LockoutDuration { return false, nil } // Under threshold — not locked out. if count < LockoutThreshold { return false, nil } // Threshold exceeded; locked out until window_start + LockoutDuration. lockedUntil := windowStart.Add(LockoutDuration) return time.Now().Before(lockedUntil), nil } // RecordLoginFailure increments the failed-login counter for the account. // If the current window has expired a new window is started. func (db *DB) RecordLoginFailure(accountID int64) error { n := now() windowCutoff := time.Now().Add(-LockoutWindow).UTC().Format(time.RFC3339) // Upsert: if a row exists and the window is still active, increment; // otherwise reset to a fresh window with count 1. _, err := db.sql.Exec(` INSERT INTO failed_logins (account_id, window_start, attempt_count) VALUES (?, ?, 1) ON CONFLICT(account_id) DO UPDATE SET window_start = CASE WHEN window_start < ? THEN excluded.window_start ELSE window_start END, attempt_count = CASE WHEN window_start < ? 
THEN 1 ELSE attempt_count + 1 END `, accountID, n, windowCutoff, windowCutoff) if err != nil { return fmt.Errorf("db: record login failure %d: %w", accountID, err) } return nil } // ClearLoginFailures resets the failed-login counter for the account. // Called on successful login. func (db *DB) ClearLoginFailures(accountID int64) error { _, err := db.sql.Exec(`DELETE FROM failed_logins WHERE account_id = ?`, accountID) if err != nil { return fmt.Errorf("db: clear login failures %d: %w", accountID, err) } return nil } // WriteAuditEvent appends an audit log entry. // Details must never contain credential material. func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error { _, err := db.sql.Exec(` INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details) VALUES (?, ?, ?, ?, ?) `, eventType, actorID, targetID, nullString(ipAddress), nullString(details)) if err != nil { return fmt.Errorf("db: write audit event %q: %w", eventType, err) } return nil } // TrackToken records a newly issued JWT JTI for revocation tracking. func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error { _, err := db.sql.Exec(` INSERT INTO token_revocation (jti, account_id, issued_at, expires_at) VALUES (?, ?, ?, ?) `, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339)) if err != nil { return fmt.Errorf("db: track token %q: %w", jti, err) } return nil } // GetTokenRecord retrieves a token record by JTI. // Returns ErrNotFound if no record exists (token was never issued by this server). func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) { var rec model.TokenRecord var issuedAtStr, expiresAtStr, createdAtStr string var revokedAtStr *string var revokeReason *string err := db.sql.QueryRow(` SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at FROM token_revocation WHERE jti = ? 
`, jti).Scan( &rec.ID, &rec.JTI, &rec.AccountID, &expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason, &createdAtStr, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: get token record %q: %w", jti, err) } var parseErr error rec.ExpiresAt, parseErr = parseTime(expiresAtStr) if parseErr != nil { return nil, parseErr } rec.IssuedAt, parseErr = parseTime(issuedAtStr) if parseErr != nil { return nil, parseErr } rec.CreatedAt, parseErr = parseTime(createdAtStr) if parseErr != nil { return nil, parseErr } rec.RevokedAt, parseErr = nullableTime(revokedAtStr) if parseErr != nil { return nil, parseErr } if revokeReason != nil { rec.RevokeReason = *revokeReason } return &rec, nil } // RevokeToken marks a token as revoked by JTI. func (db *DB) RevokeToken(jti, reason string) error { n := now() result, err := db.sql.Exec(` UPDATE token_revocation SET revoked_at = ?, revoke_reason = ? WHERE jti = ? AND revoked_at IS NULL `, n, nullString(reason), jti) if err != nil { return fmt.Errorf("db: revoke token %q: %w", jti, err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("db: revoke token rows affected: %w", err) } if rows == 0 { return fmt.Errorf("db: token %q not found or already revoked", jti) } return nil } // RenewToken atomically revokes the old token and tracks the new one in a // single SQLite transaction. // // Security: the two operations must be atomic so that a failure between them // cannot leave the user with neither a valid old token nor a new one. With // MaxOpenConns(1) and SQLite's serialised write path, BEGIN IMMEDIATE acquires // the write lock immediately and prevents any other writer from interleaving. 
func (db *DB) RenewToken(oldJTI, reason, newJTI string, accountID int64, issuedAt, expiresAt time.Time) error { tx, err := db.sql.Begin() if err != nil { return fmt.Errorf("db: renew token begin tx: %w", err) } defer func() { _ = tx.Rollback() }() n := now() // Revoke the old token. result, err := tx.Exec(` UPDATE token_revocation SET revoked_at = ?, revoke_reason = ? WHERE jti = ? AND revoked_at IS NULL `, n, nullString(reason), oldJTI) if err != nil { return fmt.Errorf("db: renew token revoke old %q: %w", oldJTI, err) } rows, err := result.RowsAffected() if err != nil { return fmt.Errorf("db: renew token revoke rows affected: %w", err) } if rows == 0 { return fmt.Errorf("db: renew token: old token %q not found or already revoked", oldJTI) } // Track the new token. _, err = tx.Exec(` INSERT INTO token_revocation (jti, account_id, issued_at, expires_at) VALUES (?, ?, ?, ?) `, newJTI, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339)) if err != nil { return fmt.Errorf("db: renew token track new %q: %w", newJTI, err) } if err := tx.Commit(); err != nil { return fmt.Errorf("db: renew token commit: %w", err) } return nil } // RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account. func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error { n := now() _, err := db.sql.Exec(` UPDATE token_revocation SET revoked_at = ?, revoke_reason = ? WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ? `, n, nullString(reason), accountID, n) if err != nil { return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err) } return nil } // PruneExpiredTokens removes token_revocation rows that are past their expiry. // Returns the number of rows deleted. func (db *DB) PruneExpiredTokens() (int64, error) { result, err := db.sql.Exec(` DELETE FROM token_revocation WHERE expires_at < ? 
`, now()) if err != nil { return 0, fmt.Errorf("db: prune expired tokens: %w", err) } return result.RowsAffected() } // ListTokensForAccount returns all token_revocation rows for the given account, // ordered by issued_at descending (newest first). func (db *DB) ListTokensForAccount(accountID int64) ([]*model.TokenRecord, error) { rows, err := db.sql.Query(` SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at FROM token_revocation WHERE account_id = ? ORDER BY issued_at DESC `, accountID) if err != nil { return nil, fmt.Errorf("db: list tokens for account %d: %w", accountID, err) } defer func() { _ = rows.Close() }() var records []*model.TokenRecord for rows.Next() { var rec model.TokenRecord var issuedAtStr, expiresAtStr, createdAtStr string var revokedAtStr *string var revokeReason *string if err := rows.Scan( &rec.ID, &rec.JTI, &rec.AccountID, &expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason, &createdAtStr, ); err != nil { return nil, fmt.Errorf("db: scan token record: %w", err) } var parseErr error rec.ExpiresAt, parseErr = parseTime(expiresAtStr) if parseErr != nil { return nil, parseErr } rec.IssuedAt, parseErr = parseTime(issuedAtStr) if parseErr != nil { return nil, parseErr } rec.CreatedAt, parseErr = parseTime(createdAtStr) if parseErr != nil { return nil, parseErr } rec.RevokedAt, parseErr = nullableTime(revokedAtStr) if parseErr != nil { return nil, parseErr } if revokeReason != nil { rec.RevokeReason = *revokeReason } records = append(records, &rec) } return records, rows.Err() } // AuditQueryParams filters for ListAuditEvents and ListAuditEventsPaged. type AuditQueryParams struct { AccountID *int64 Since *time.Time EventType string Limit int Offset int } // AuditEventView extends AuditEvent with resolved actor/target usernames for display. // Usernames are resolved via a LEFT JOIN and are empty if the actor/target is unknown. 
// The fieldalignment hint is suppressed: the embedded model.AuditEvent layout is fixed // and changing to explicit fields would break JSON serialisation. type AuditEventView struct { //nolint:govet model.AuditEvent ActorUsername string `json:"actor_username,omitempty"` TargetUsername string `json:"target_username,omitempty"` } // ListAuditEvents returns audit log entries matching the given parameters, // ordered by event_time ascending. Limit rows are returned if Limit > 0. func (db *DB) ListAuditEvents(p AuditQueryParams) ([]*model.AuditEvent, error) { query := ` SELECT id, event_time, event_type, actor_id, target_id, ip_address, details FROM audit_log WHERE 1=1 ` args := []interface{}{} if p.AccountID != nil { query += ` AND (actor_id = ? OR target_id = ?)` args = append(args, *p.AccountID, *p.AccountID) } if p.EventType != "" { query += ` AND event_type = ?` args = append(args, p.EventType) } if p.Since != nil { query += ` AND event_time >= ?` args = append(args, p.Since.UTC().Format(time.RFC3339)) } query += ` ORDER BY event_time ASC, id ASC` if p.Limit > 0 { query += ` LIMIT ?` args = append(args, p.Limit) } rows, err := db.sql.Query(query, args...) if err != nil { return nil, fmt.Errorf("db: list audit events: %w", err) } defer func() { _ = rows.Close() }() var events []*model.AuditEvent for rows.Next() { var ev model.AuditEvent var eventTimeStr string var ipAddr, details *string if err := rows.Scan( &ev.ID, &eventTimeStr, &ev.EventType, &ev.ActorID, &ev.TargetID, &ipAddr, &details, ); err != nil { return nil, fmt.Errorf("db: scan audit event: %w", err) } ev.EventTime, err = parseTime(eventTimeStr) if err != nil { return nil, err } if ipAddr != nil { ev.IPAddress = *ipAddr } if details != nil { ev.Details = *details } events = append(events, &ev) } return events, rows.Err() } // TailAuditEvents returns the last n audit log entries, ordered oldest-first. 
func (db *DB) TailAuditEvents(n int) ([]*model.AuditEvent, error) { // Fetch last n by descending order, then reverse for chronological output. rows, err := db.sql.Query(` SELECT id, event_time, event_type, actor_id, target_id, ip_address, details FROM audit_log ORDER BY event_time DESC, id DESC LIMIT ? `, n) if err != nil { return nil, fmt.Errorf("db: tail audit events: %w", err) } defer func() { _ = rows.Close() }() var events []*model.AuditEvent for rows.Next() { var ev model.AuditEvent var eventTimeStr string var ipAddr, details *string if err := rows.Scan( &ev.ID, &eventTimeStr, &ev.EventType, &ev.ActorID, &ev.TargetID, &ipAddr, &details, ); err != nil { return nil, fmt.Errorf("db: scan audit event: %w", err) } var parseErr error ev.EventTime, parseErr = parseTime(eventTimeStr) if parseErr != nil { return nil, parseErr } if ipAddr != nil { ev.IPAddress = *ipAddr } if details != nil { ev.Details = *details } events = append(events, &ev) } if err := rows.Err(); err != nil { return nil, err } // Reverse to oldest-first. for i, j := 0, len(events)-1; i < j; i, j = i+1, j-1 { events[i], events[j] = events[j], events[i] } return events, nil } // ListAuditEventsPaged returns audit log entries matching params, newest first, // with LEFT JOINed actor/target usernames for display. Returns the matching rows // and the total count of matching rows (for pagination). // // Security: No credential material is included in audit_log rows per the // WriteAuditEvent contract; joining account usernames is safe for display. func (db *DB) ListAuditEventsPaged(p AuditQueryParams) ([]*AuditEventView, int64, error) { // Build the shared WHERE clause and args. where := " WHERE 1=1" args := []interface{}{} if p.AccountID != nil { where += ` AND (al.actor_id = ? 
OR al.target_id = ?)` args = append(args, *p.AccountID, *p.AccountID) } if p.EventType != "" { where += ` AND al.event_type = ?` args = append(args, p.EventType) } if p.Since != nil { where += ` AND al.event_time >= ?` args = append(args, p.Since.UTC().Format(time.RFC3339)) } // Count total matching rows first. countQuery := `SELECT COUNT(*) FROM audit_log al` + where var total int64 if err := db.sql.QueryRow(countQuery, args...).Scan(&total); err != nil { return nil, 0, fmt.Errorf("db: count audit events: %w", err) } // Fetch the page with username resolution via LEFT JOIN. query := ` SELECT al.id, al.event_time, al.event_type, al.actor_id, al.target_id, al.ip_address, al.details, COALESCE(a1.username, ''), COALESCE(a2.username, '') FROM audit_log al LEFT JOIN accounts a1 ON al.actor_id = a1.id LEFT JOIN accounts a2 ON al.target_id = a2.id` + where + ` ORDER BY al.event_time DESC, al.id DESC` pageArgs := append(args, p.Limit, p.Offset) //nolint:gocritic // intentional new slice query += ` LIMIT ? OFFSET ?` rows, err := db.sql.Query(query, pageArgs...) if err != nil { return nil, 0, fmt.Errorf("db: list audit events paged: %w", err) } defer func() { _ = rows.Close() }() var events []*AuditEventView for rows.Next() { var ev AuditEventView var eventTimeStr string var ipAddr, details *string if err := rows.Scan( &ev.ID, &eventTimeStr, &ev.EventType, &ev.ActorID, &ev.TargetID, &ipAddr, &details, &ev.ActorUsername, &ev.TargetUsername, ); err != nil { return nil, 0, fmt.Errorf("db: scan audit event view: %w", err) } ev.EventTime, err = parseTime(eventTimeStr) if err != nil { return nil, 0, err } if ipAddr != nil { ev.IPAddress = *ipAddr } if details != nil { ev.Details = *details } events = append(events, &ev) } if err := rows.Err(); err != nil { return nil, 0, err } return events, total, nil } // GetAuditEventByID fetches a single audit event by its integer primary key, // with actor/target usernames resolved via LEFT JOIN. Returns ErrNotFound if // no row matches. 
func (db *DB) GetAuditEventByID(id int64) (*AuditEventView, error) { row := db.sql.QueryRow(` SELECT al.id, al.event_time, al.event_type, al.actor_id, al.target_id, al.ip_address, al.details, COALESCE(a1.username, ''), COALESCE(a2.username, '') FROM audit_log al LEFT JOIN accounts a1 ON al.actor_id = a1.id LEFT JOIN accounts a2 ON al.target_id = a2.id WHERE al.id = ? `, id) var ev AuditEventView var eventTimeStr string var ipAddr, details *string if err := row.Scan( &ev.ID, &eventTimeStr, &ev.EventType, &ev.ActorID, &ev.TargetID, &ipAddr, &details, &ev.ActorUsername, &ev.TargetUsername, ); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, fmt.Errorf("db: get audit event %d: %w", id, err) } var err error ev.EventTime, err = parseTime(eventTimeStr) if err != nil { return nil, err } if ipAddr != nil { ev.IPAddress = *ipAddr } if details != nil { ev.Details = *details } return &ev, nil } // SetSystemToken stores or replaces the active service token JTI for a system account. func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error { n := now() _, err := db.sql.Exec(` INSERT INTO system_tokens (account_id, jti, expires_at, created_at) VALUES (?, ?, ?, ?) ON CONFLICT(account_id) DO UPDATE SET jti = excluded.jti, expires_at = excluded.expires_at, created_at = excluded.created_at `, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n) if err != nil { return fmt.Errorf("db: set system token for account %d: %w", accountID, err) } return nil } // GetSystemToken retrieves the active service token record for a system account. func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) { var st model.SystemToken var expiresAtStr, createdAtStr string err := db.sql.QueryRow(` SELECT id, account_id, jti, expires_at, created_at FROM system_tokens WHERE account_id = ? 
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: get system token: %w", err) } var parseErr error st.ExpiresAt, parseErr = parseTime(expiresAtStr) if parseErr != nil { return nil, parseErr } st.CreatedAt, parseErr = parseTime(createdAtStr) if parseErr != nil { return nil, parseErr } return &st, nil } // Lockout parameters (package-level vars so tests can override them). // // Security (F-08): per-account failed-login tracking prevents brute-force // attacks. LockoutWindow defines the rolling window during which failures // are counted; LockoutThreshold is the number of failures that triggers a // lockout; LockoutDuration is how long the account remains locked after the // threshold is reached. All three are intentionally kept as vars (not // consts) so that tests can reduce them to millisecond-scale values without // recompiling. var ( LockoutWindow = 15 * time.Minute LockoutThreshold = 10 LockoutDuration = 15 * time.Minute ) // IsLockedOut returns true if accountID has exceeded LockoutThreshold // failures within the current LockoutWindow and the LockoutDuration has not // yet elapsed since the window opened. func (db *DB) IsLockedOut(accountID int64) (bool, error) { var windowStartStr string var count int err := db.sql.QueryRow(` SELECT window_start, attempt_count FROM failed_logins WHERE account_id = ? `, accountID).Scan(&windowStartStr, &count) if errors.Is(err, sql.ErrNoRows) { return false, nil } if err != nil { return false, fmt.Errorf("db: is locked out %d: %w", accountID, err) } windowStart, err := parseTime(windowStartStr) if err != nil { return false, fmt.Errorf("db: parse lockout window_start: %w", err) } // The window has expired: the record is stale, the account is not locked. 
if time.Now().After(windowStart.Add(LockoutWindow)) { return false, nil } return count >= LockoutThreshold, nil } // RecordLoginFailure increments the failure counter for accountID within the // current rolling window. If the window has expired the counter resets to 1 // and the window_start is updated. Uses an UPSERT so the operation is safe // to call without a prior existence check. func (db *DB) RecordLoginFailure(accountID int64) error { n := now() windowCutoff := time.Now().Add(-LockoutWindow).UTC().Format(time.RFC3339) _, err := db.sql.Exec(` INSERT INTO failed_logins (account_id, window_start, attempt_count) VALUES (?, ?, 1) ON CONFLICT(account_id) DO UPDATE SET window_start = CASE WHEN window_start < ? THEN excluded.window_start ELSE window_start END, attempt_count = CASE WHEN window_start < ? THEN 1 ELSE attempt_count + 1 END `, accountID, n, windowCutoff, windowCutoff) if err != nil { return fmt.Errorf("db: record login failure for account %d: %w", accountID, err) } return nil } // ClearLoginFailures removes the failure record for accountID. Called on a // successful login to reset the lockout state. func (db *DB) ClearLoginFailures(accountID int64) error { _, err := db.sql.Exec(`DELETE FROM failed_logins WHERE account_id = ?`, accountID) if err != nil { return fmt.Errorf("db: clear login failures for account %d: %w", accountID, err) } return nil }