Files
mcias/internal/db/accounts.go
Kyle Isom 9b0adfdde4 Fix F-08, F-13: Adjust lockout expiration logic and enforce password length in tests
- Corrected lockout logic (`IsLockedOut`) to properly evaluate failed login thresholds within the rolling window, ensuring stale attempts outside the window do not trigger lockout.
- Updated test passwords in `grpcserver_test.go` to comply with 12-character minimum requirement.
- Reformatted import blocks with `goimports` to address lint warnings.
- Verified all tests pass and linter is clean.
2026-03-11 21:36:04 -07:00

1096 lines
34 KiB
Go

package db
import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
"github.com/google/uuid"
)
// CreateAccount inserts a new account in the 'active' state and returns it.
//
// A random UUID (uuid.New) becomes the account's external identifier. The
// numeric ID is the row ID reported by the database after the insert. An
// empty passwordHash is stored as SQL NULL. CreatedAt and UpdatedAt on the
// returned account are both the insertion timestamp.
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
	accountUUID := uuid.New().String()
	stamp := now()

	res, err := db.sql.Exec(`
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
VALUES (?, ?, ?, ?, 'active', ?, ?)
`, accountUUID, username, string(accountType), nullString(passwordHash), stamp, stamp)
	if err != nil {
		return nil, fmt.Errorf("db: create account %q: %w", username, err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
	}

	ts, err := parseTime(stamp)
	if err != nil {
		return nil, err
	}

	acct := &model.Account{
		ID:           newID,
		UUID:         accountUUID,
		Username:     username,
		AccountType:  accountType,
		Status:       model.AccountStatusActive,
		PasswordHash: passwordHash,
		CreatedAt:    ts,
		UpdatedAt:    ts,
	}
	return acct, nil
}
// GetAccountByUUID looks up an account by its external UUID. No status filter
// is applied, so soft-deleted accounts are returned too. It returns
// ErrNotFound when no row has the given UUID.
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE uuid = ?
`, accountUUID)
	return db.scanAccount(row)
}
// GetAccountByID looks up an account by its numeric primary key. No status
// filter is applied, so soft-deleted accounts are returned too. It returns
// ErrNotFound when no row has the given ID.
func (db *DB) GetAccountByID(id int64) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE id = ?
`, id)
	return db.scanAccount(row)
}
// GetAccountByUsername looks up an account by username. No status filter is
// applied, so soft-deleted accounts are returned too. It returns ErrNotFound
// when no row matches.
//
// NOTE(review): the query compares with `username = ?`. That comparison is
// case-sensitive in SQLite unless the column is declared COLLATE NOCASE, so
// case-insensitive matching depends on the schema. Confirm before relying on it.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE username = ?
`, username)
	return db.scanAccount(row)
}
// ListAccounts returns every account whose status is not 'deleted', sorted by
// username ascending. The slice is nil when no account matches. If iteration
// fails part-way, the rows read so far are returned alongside rows.Err().
func (db *DB) ListAccounts() ([]*model.Account, error) {
	rows, err := db.sql.Query(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts
WHERE status != 'deleted'
ORDER BY username ASC
`)
	if err != nil {
		return nil, fmt.Errorf("db: list accounts: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var out []*model.Account
	for rows.Next() {
		acct, scanErr := db.scanAccountRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, acct)
	}
	return out, rows.Err()
}
// UpdateAccountStatus sets an account's status and refreshes updated_at.
//
// deleted_at is set to the current timestamp when status is
// AccountStatusDeleted and cleared to NULL for any other status. Moving an
// account out of the deleted state therefore also clears its deletion time.
// No error is returned when accountID matches no row.
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
	stamp := now()

	var deletedAt *string
	if status == model.AccountStatusDeleted {
		deletedAt = &stamp
	}

	if _, err := db.sql.Exec(`
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
WHERE id = ?
`, string(status), deletedAt, stamp, accountID); err != nil {
		return fmt.Errorf("db: update account status: %w", err)
	}
	return nil
}
// UpdatePasswordHash replaces the stored Argon2id password hash for an account
// and refreshes updated_at. The hash is written as given; unlike
// CreateAccount, an empty string is not converted to NULL. No error is
// returned when accountID matches no row.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
	if _, err := db.sql.Exec(`
UPDATE accounts SET password_hash = ?, updated_at = ?
WHERE id = ?
`, hash, now(), accountID); err != nil {
		return fmt.Errorf("db: update password hash: %w", err)
	}
	return nil
}
// StorePendingTOTP saves an encrypted TOTP secret and its nonce without
// touching the totp_required flag. It is meant for the enrolment step, so the
// secret is available when the user confirms a code. The confirmation step is
// expected to call SetTOTP, which turns the flag on.
//
// Security: leaving totp_required unchanged during enrolment means a user who
// abandons the flow is not locked out of password login.
//
// NOTE(review): if TOTP is already enabled (totp_required = 1), this
// overwrites the active secret with an unconfirmed one while the flag stays
// on. Confirm that callers reject re-enrolment in that state.
func (db *DB) StorePendingTOTP(accountID int64, secretEnc, secretNonce []byte) error {
	if _, err := db.sql.Exec(`
UPDATE accounts
SET totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`, secretEnc, secretNonce, now(), accountID); err != nil {
		return fmt.Errorf("db: store pending TOTP: %w", err)
	}
	return nil
}
// SetTOTP saves an encrypted TOTP secret and its nonce and sets
// totp_required = 1, so TOTP is enforced from then on. Use it only after the
// user has confirmed a code; the initial enrolment step should use
// StorePendingTOTP instead.
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
	if _, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`, secretEnc, secretNonce, now(), accountID); err != nil {
		return fmt.Errorf("db: set TOTP: %w", err)
	}
	return nil
}
// ClearTOTP turns TOTP off for an account. It sets totp_required = 0 and sets
// both the encrypted secret and its nonce to NULL.
func (db *DB) ClearTOTP(accountID int64) error {
	if _, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
WHERE id = ?
`, now(), accountID); err != nil {
		return fmt.Errorf("db: clear TOTP: %w", err)
	}
	return nil
}
// scanAccount reads one account from a single-row query. The row must select
// the standard account column list. sql.ErrNoRows is mapped to ErrNotFound,
// and type conversion is delegated to finishAccountScan.
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
	var (
		acct                   model.Account
		kind, state            string
		totpFlag               int
		createdRaw, updatedRaw string
		deletedRaw             *string
		secretEnc, secretNonce []byte
	)
	err := row.Scan(
		&acct.ID, &acct.UUID, &acct.Username,
		&kind, &acct.PasswordHash,
		&state, &totpFlag,
		&secretEnc, &secretNonce,
		&createdRaw, &updatedRaw, &deletedRaw,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: scan account: %w", err)
	}
	return finishAccountScan(&acct, kind, state, totpFlag, secretEnc, secretNonce, createdRaw, updatedRaw, deletedRaw)
}
// scanAccountRow reads the account at the current cursor position of rows.
// The rows must select the standard account column list. Scan errors are
// wrapped; type conversion is delegated to finishAccountScan.
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
	var (
		acct                   model.Account
		kind, state            string
		totpFlag               int
		createdRaw, updatedRaw string
		deletedRaw             *string
		secretEnc, secretNonce []byte
	)
	if err := rows.Scan(
		&acct.ID, &acct.UUID, &acct.Username,
		&kind, &acct.PasswordHash,
		&state, &totpFlag,
		&secretEnc, &secretNonce,
		&createdRaw, &updatedRaw, &deletedRaw,
	); err != nil {
		return nil, fmt.Errorf("db: scan account row: %w", err)
	}
	return finishAccountScan(&acct, kind, state, totpFlag, secretEnc, secretNonce, createdRaw, updatedRaw, deletedRaw)
}
// finishAccountScan converts the raw column values read by scanAccount and
// scanAccountRow into the typed fields of a, and returns a.
//
// TOTPRequired is true only when the stored flag equals 1. created_at and
// updated_at are parsed with parseTime, and deleted_at with nullableTime. The
// first parse error is returned as-is, with a nil account.
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
	a.AccountType = model.AccountType(accountType)
	a.Status = model.AccountStatus(status)
	a.TOTPRequired = totpRequired == 1
	a.TOTPSecretEnc, a.TOTPSecretNonce = totpSecretEnc, totpSecretNonce

	var err error
	if a.CreatedAt, err = parseTime(createdAtStr); err != nil {
		return nil, err
	}
	if a.UpdatedAt, err = parseTime(updatedAtStr); err != nil {
		return nil, err
	}
	if a.DeletedAt, err = nullableTime(deletedAtStr); err != nil {
		return nil, err
	}
	return a, nil
}
// nullString maps the empty string to a nil pointer, which database/sql binds
// as SQL NULL. Any other value is returned as a pointer to a copy of s.
func nullString(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
// GetRoles returns the roles granted to accountID, sorted alphabetically.
// An account with no roles yields a nil slice. If iteration fails part-way,
// the roles read so far are returned alongside rows.Err().
func (db *DB) GetRoles(accountID int64) ([]string, error) {
	rows, err := db.sql.Query(`
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
`, accountID)
	if err != nil {
		return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
	}
	defer func() { _ = rows.Close() }()

	var granted []string
	for rows.Next() {
		var r string
		if scanErr := rows.Scan(&r); scanErr != nil {
			return nil, fmt.Errorf("db: scan role: %w", scanErr)
		}
		granted = append(granted, r)
	}
	return granted, rows.Err()
}
// GrantRole grants role to accountID, recording grantedBy (stored as NULL
// when nil) and the current time. The statement is INSERT OR IGNORE, so
// granting a role the account already holds is a no-op rather than an error.
// This presumably relies on a uniqueness constraint on (account_id, role) in
// the schema; confirm.
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
	if _, err := db.sql.Exec(`
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, now()); err != nil {
		return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
	}
	return nil
}
// RevokeRole removes role from accountID. Revoking a role the account does not
// hold is not an error.
func (db *DB) RevokeRole(accountID int64, role string) error {
	if _, err := db.sql.Exec(`
DELETE FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role); err != nil {
		return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
	}
	return nil
}
// SetRoles atomically replaces every role held by accountID with roles.
//
// Inside one transaction, the existing grants are deleted and each entry of
// roles is inserted with the same grantedBy and timestamp. Any failure rolls
// the whole change back. A nil or empty roles slice removes all roles.
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
	tx, err := db.sql.Begin()
	if err != nil {
		return fmt.Errorf("db: set roles begin tx: %w", err)
	}
	// Undo the transaction on every early return. After a successful Commit,
	// Rollback just returns sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
		return fmt.Errorf("db: set roles delete existing: %w", err)
	}

	grantedAt := now()
	for _, r := range roles {
		if _, err := tx.Exec(`
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, r, grantedBy, grantedAt); err != nil {
			return fmt.Errorf("db: set roles insert %q: %w", r, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("db: set roles commit: %w", err)
	}
	return nil
}
// HasRole reports whether accountID currently holds role.
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
	var matches int
	if err := db.sql.QueryRow(`
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role).Scan(&matches); err != nil {
		return false, fmt.Errorf("db: has role: %w", err)
	}
	return matches > 0, nil
}
// WriteServerConfig stores the encrypted Ed25519 signing key and its nonce in
// the singleton server_config row (id = 1).
//
// If the row does not exist it is created, with created_at and updated_at set
// to now. Otherwise only the two key columns and updated_at are overwritten;
// other columns such as master_key_salt and created_at are left unchanged.
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
VALUES (1, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
signing_key_enc = excluded.signing_key_enc,
signing_key_nonce = excluded.signing_key_nonce,
updated_at = excluded.updated_at
`, signingKeyEnc, signingKeyNonce, stamp, stamp); err != nil {
		return fmt.Errorf("db: write server config: %w", err)
	}
	return nil
}
// ReadServerConfig returns the encrypted Ed25519 signing key and its nonce
// from the singleton server_config row (id = 1).
//
// It returns ErrNotFound when the row does not exist. It also returns
// ErrNotFound when the row exists but either signing-key column is NULL or
// empty. That can happen when WriteMasterKeySalt created the row, because its
// INSERT does not set the signing-key columns. Without this check, callers
// would get (nil, nil, nil) and could not tell "no key yet" from a real key.
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
	err = db.sql.QueryRow(`
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
`).Scan(&signingKeyEnc, &signingKeyNonce)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil, ErrNotFound
	}
	if err != nil {
		return nil, nil, fmt.Errorf("db: read server config: %w", err)
	}
	// A NULL column scans into a nil []byte; treat missing or empty material
	// as "not configured yet" rather than as a usable key.
	if len(signingKeyEnc) == 0 || len(signingKeyNonce) == 0 {
		return nil, nil, ErrNotFound
	}
	return signingKeyEnc, signingKeyNonce, nil
}
// WriteMasterKeySalt stores the Argon2id KDF salt used to derive the master
// key. It writes to the singleton server_config row (id = 1). The salt must
// stay stable across restarts so the same passphrase always yields the same
// master key.
//
// If the row does not exist it is created with created_at and updated_at set
// to now. Otherwise only master_key_salt and updated_at are overwritten; the
// signing-key columns and created_at are left unchanged.
func (db *DB) WriteMasterKeySalt(salt []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
VALUES (1, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
master_key_salt = excluded.master_key_salt,
updated_at = excluded.updated_at
`, salt, stamp, stamp); err != nil {
		return fmt.Errorf("db: write master key salt: %w", err)
	}
	return nil
}
// ReadMasterKeySalt returns the stored Argon2id KDF salt from the singleton
// server_config row (id = 1).
//
// It returns ErrNotFound when no salt has been stored yet, for example on
// first run. That covers three cases:
//   - the row is missing;
//   - master_key_salt is NULL, because the row was created by
//     WriteServerConfig, which does not set it;
//   - the salt is zero-length.
//
// An empty salt is unusable for key derivation, so it is treated the same as
// an absent one.
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
	var salt []byte
	err := db.sql.QueryRow(`
SELECT master_key_salt FROM server_config WHERE id = 1
`).Scan(&salt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("db: read master key salt: %w", err)
	}
	// len covers both a NULL column (nil slice) and an empty BLOB.
	if len(salt) == 0 {
		return nil, ErrNotFound
	}
	return salt, nil
}
// WritePGCredentials stores the Postgres connection details and encrypted
// password for an account, replacing any existing row.
//
// If the account already has a pg_credentials row, every credential column
// and updated_at are overwritten and created_at is kept. Otherwise a new row
// is inserted with both timestamps set to now.
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
INSERT INTO pg_credentials
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
pg_host = excluded.pg_host,
pg_port = excluded.pg_port,
pg_database = excluded.pg_database,
pg_username = excluded.pg_username,
pg_password_enc = excluded.pg_password_enc,
pg_password_nonce = excluded.pg_password_nonce,
updated_at = excluded.updated_at
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, stamp, stamp); err != nil {
		return fmt.Errorf("db: write pg credentials: %w", err)
	}
	return nil
}
// ReadPGCredentials loads the Postgres credentials stored for an account. The
// password is still encrypted, with its nonce alongside. It returns
// ErrNotFound when the account has no pg_credentials row.
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
	var (
		cred                   model.PGCredential
		createdRaw, updatedRaw string
	)
	err := db.sql.QueryRow(`
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
pg_password_enc, pg_password_nonce, created_at, updated_at
FROM pg_credentials WHERE account_id = ?
`, accountID).Scan(
		&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
		&cred.PGDatabase, &cred.PGUsername,
		&cred.PGPasswordEnc, &cred.PGPasswordNonce,
		&createdRaw, &updatedRaw,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: read pg credentials: %w", err)
	}

	if cred.CreatedAt, err = parseTime(createdRaw); err != nil {
		return nil, err
	}
	if cred.UpdatedAt, err = parseTime(updatedRaw); err != nil {
		return nil, err
	}
	return &cred, nil
}
// WriteAuditEvent appends one entry to the audit log.
//
// actorID and targetID may be nil, and an empty ipAddress or details is
// stored as NULL. The event timestamp is not supplied here; presumably the
// audit_log schema defaults it. details must never contain credential material.
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
	if _, err := db.sql.Exec(`
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
VALUES (?, ?, ?, ?, ?)
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details)); err != nil {
		return fmt.Errorf("db: write audit event %q: %w", eventType, err)
	}
	return nil
}
// TrackToken records a newly issued JWT (by its JTI) so it can later be
// revoked. issuedAt and expiresAt are stored as UTC time.RFC3339 strings,
// which have second precision.
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
	issued := issuedAt.UTC().Format(time.RFC3339)
	expires := expiresAt.UTC().Format(time.RFC3339)
	if _, err := db.sql.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, jti, accountID, issued, expires); err != nil {
		return fmt.Errorf("db: track token %q: %w", jti, err)
	}
	return nil
}
// GetTokenRecord loads the tracking row for the token with the given JTI.
// RevokedAt is nil and RevokeReason is "" unless the token has been revoked.
// It returns ErrNotFound when no row exists, meaning this server never issued
// (or has since pruned) the token.
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
	var (
		rec                              model.TokenRecord
		issuedRaw, expiresRaw, createdRaw string
		revokedRaw, reason               *string
	)
	err := db.sql.QueryRow(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation WHERE jti = ?
`, jti).Scan(
		&rec.ID, &rec.JTI, &rec.AccountID,
		&expiresRaw, &issuedRaw, &revokedRaw, &reason,
		&createdRaw,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
	}

	if rec.ExpiresAt, err = parseTime(expiresRaw); err != nil {
		return nil, err
	}
	if rec.IssuedAt, err = parseTime(issuedRaw); err != nil {
		return nil, err
	}
	if rec.CreatedAt, err = parseTime(createdRaw); err != nil {
		return nil, err
	}
	if rec.RevokedAt, err = nullableTime(revokedRaw); err != nil {
		return nil, err
	}
	if reason != nil {
		rec.RevokeReason = *reason
	}
	return &rec, nil
}
// RevokeToken marks the token with the given JTI as revoked now, recording
// reason (stored as NULL when empty). Only a token that is not yet revoked is
// updated. An error is returned when no row was changed, i.e. the JTI is
// unknown or the token was already revoked.
func (db *DB) RevokeToken(jti, reason string) error {
	res, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, now(), nullString(reason), jti)
	if err != nil {
		return fmt.Errorf("db: revoke token %q: %w", jti, err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("db: revoke token rows affected: %w", err)
	case affected == 0:
		return fmt.Errorf("db: token %q not found or already revoked", jti)
	}
	return nil
}
// RenewToken revokes the old token and starts tracking its replacement, in a
// single database transaction.
//
// The old token must exist and not already be revoked; otherwise the
// transaction is rolled back and an error is returned. The new token's
// timestamps are stored as UTC time.RFC3339 strings, as in TrackToken.
//
// Security: both steps must commit together. A failure between them must not
// leave the user with neither a valid old token nor a new one.
//
// NOTE(review): the transaction is opened with db.sql.Begin(), which issues
// the driver's default BEGIN (DEFERRED in SQLite). It is BEGIN IMMEDIATE only
// if the connection DSN configures it. Writer serialisation presumably also
// relies on the pool being limited to one connection elsewhere. Confirm both.
func (db *DB) RenewToken(oldJTI, reason, newJTI string, accountID int64, issuedAt, expiresAt time.Time) error {
	tx, err := db.sql.Begin()
	if err != nil {
		return fmt.Errorf("db: renew token begin tx: %w", err)
	}
	// Undo everything on any early return; a no-op once Commit has succeeded.
	defer func() { _ = tx.Rollback() }()

	// Step 1: revoke the old token, but only if it is still live.
	res, err := tx.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, now(), nullString(reason), oldJTI)
	if err != nil {
		return fmt.Errorf("db: renew token revoke old %q: %w", oldJTI, err)
	}
	revoked, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("db: renew token revoke rows affected: %w", err)
	}
	if revoked == 0 {
		return fmt.Errorf("db: renew token: old token %q not found or already revoked", oldJTI)
	}

	// Step 2: register the replacement token.
	issued := issuedAt.UTC().Format(time.RFC3339)
	expires := expiresAt.UTC().Format(time.RFC3339)
	if _, err := tx.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, newJTI, accountID, issued, expires); err != nil {
		return fmt.Errorf("db: renew token track new %q: %w", newJTI, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("db: renew token commit: %w", err)
	}
	return nil
}
// RevokeAllUserTokens revokes every token of accountID that is neither
// revoked nor expired. One timestamp is used both as revoked_at and as the
// expiry cut-off. reason is stored as NULL when empty. Revoking nothing is not
// an error.
//
// NOTE(review): expires_at is compared as text with now(). This is correct
// only if now() emits the same fixed-width UTC RFC 3339 layout that
// TrackToken writes. Confirm.
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
	stamp := now()
	if _, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
`, stamp, nullString(reason), accountID, stamp); err != nil {
		return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
	}
	return nil
}
// PruneExpiredTokens deletes every token_revocation row whose expires_at is
// earlier than now, whether or not the token was revoked. It returns the
// number of rows deleted as reported by the driver. An error from
// RowsAffected is returned unwrapped.
func (db *DB) PruneExpiredTokens() (int64, error) {
	res, err := db.sql.Exec(`
DELETE FROM token_revocation WHERE expires_at < ?
`, now())
	if err != nil {
		return 0, fmt.Errorf("db: prune expired tokens: %w", err)
	}
	return res.RowsAffected()
}
// ListTokensForAccount returns every tracked token of accountID, revoked and
// expired ones included, newest issued_at first. The slice is nil when the
// account has none. If iteration fails part-way, the records read so far are
// returned alongside rows.Err().
func (db *DB) ListTokensForAccount(accountID int64) ([]*model.TokenRecord, error) {
	rows, err := db.sql.Query(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation
WHERE account_id = ?
ORDER BY issued_at DESC
`, accountID)
	if err != nil {
		return nil, fmt.Errorf("db: list tokens for account %d: %w", accountID, err)
	}
	defer func() { _ = rows.Close() }()

	var out []*model.TokenRecord
	for rows.Next() {
		var (
			rec                               model.TokenRecord
			issuedRaw, expiresRaw, createdRaw string
			revokedRaw, reason                *string
		)
		if scanErr := rows.Scan(
			&rec.ID, &rec.JTI, &rec.AccountID,
			&expiresRaw, &issuedRaw, &revokedRaw, &reason,
			&createdRaw,
		); scanErr != nil {
			return nil, fmt.Errorf("db: scan token record: %w", scanErr)
		}

		var perr error
		if rec.ExpiresAt, perr = parseTime(expiresRaw); perr != nil {
			return nil, perr
		}
		if rec.IssuedAt, perr = parseTime(issuedRaw); perr != nil {
			return nil, perr
		}
		if rec.CreatedAt, perr = parseTime(createdRaw); perr != nil {
			return nil, perr
		}
		if rec.RevokedAt, perr = nullableTime(revokedRaw); perr != nil {
			return nil, perr
		}
		if reason != nil {
			rec.RevokeReason = *reason
		}
		out = append(out, &rec)
	}
	return out, rows.Err()
}
// AuditQueryParams holds the optional filters and paging controls for
// ListAuditEvents and ListAuditEventsPaged. A nil pointer or empty string
// means "no filter" for the corresponding field.
type AuditQueryParams struct {
	// AccountID, when non-nil, keeps only events where this account is the
	// actor or the target.
	AccountID *int64
	// Since, when non-nil, keeps only events at or after this instant. The
	// comparison is against its UTC time.RFC3339 rendering.
	Since *time.Time
	// EventType, when non-empty, keeps only events of exactly this type.
	EventType string
	// Limit caps the number of rows returned. See each listing function for
	// how a non-positive value is treated.
	Limit int
	// Offset is the number of matching rows to skip. See each listing
	// function for how it is applied.
	Offset int
}
// AuditEventView is an audit log entry plus the usernames of its actor and
// target, for display. The usernames come from LEFT JOINs against accounts.
// They are "" when the ID is NULL or matches no account, and are then omitted
// from JSON output.
// The fieldalignment hint is suppressed: the embedded model.AuditEvent layout is fixed
// and changing to explicit fields would break JSON serialisation.
type AuditEventView struct { //nolint:govet
	// AuditEvent supplies the stored audit_log columns; its fields are
	// promoted onto AuditEventView.
	model.AuditEvent
	// ActorUsername is the username of the acting account, or "" if unknown.
	ActorUsername string `json:"actor_username,omitempty"`
	// TargetUsername is the username of the target account, or "" if unknown.
	TargetUsername string `json:"target_username,omitempty"`
}
// ListAuditEvents returns audit log entries matching p, oldest first
// (event_time ascending, ties broken by id).
//
// Filters:
//   - p.AccountID matches either actor_id or target_id;
//   - p.EventType matches exactly;
//   - p.Since keeps events at or after that instant (compared as a UTC
//     RFC 3339 string).
//
// When p.Limit > 0, at most p.Limit rows are returned; otherwise all matching
// rows are. When p.Offset > 0, that many matching rows are skipped first.
// AuditQueryParams documents Offset for this function, so it is honoured here
// rather than silently ignored. If iteration fails part-way, the events read
// so far are returned alongside rows.Err().
func (db *DB) ListAuditEvents(p AuditQueryParams) ([]*model.AuditEvent, error) {
	query := `
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
FROM audit_log
WHERE 1=1
`
	args := []interface{}{}
	if p.AccountID != nil {
		query += ` AND (actor_id = ? OR target_id = ?)`
		args = append(args, *p.AccountID, *p.AccountID)
	}
	if p.EventType != "" {
		query += ` AND event_type = ?`
		args = append(args, p.EventType)
	}
	if p.Since != nil {
		query += ` AND event_time >= ?`
		args = append(args, p.Since.UTC().Format(time.RFC3339))
	}
	query += ` ORDER BY event_time ASC, id ASC`
	if p.Limit > 0 || p.Offset > 0 {
		// SQLite only accepts OFFSET after a LIMIT clause. A negative LIMIT
		// means "no upper bound", which lets Offset be used on its own.
		limit := -1
		if p.Limit > 0 {
			limit = p.Limit
		}
		offset := 0
		if p.Offset > 0 {
			offset = p.Offset
		}
		query += ` LIMIT ? OFFSET ?`
		args = append(args, limit, offset)
	}

	rows, err := db.sql.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("db: list audit events: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var events []*model.AuditEvent
	for rows.Next() {
		var ev model.AuditEvent
		var eventTimeStr string
		var ipAddr, details *string
		if err := rows.Scan(
			&ev.ID, &eventTimeStr, &ev.EventType,
			&ev.ActorID, &ev.TargetID,
			&ipAddr, &details,
		); err != nil {
			return nil, fmt.Errorf("db: scan audit event: %w", err)
		}
		if ev.EventTime, err = parseTime(eventTimeStr); err != nil {
			return nil, err
		}
		if ipAddr != nil {
			ev.IPAddress = *ipAddr
		}
		if details != nil {
			ev.Details = *details
		}
		events = append(events, &ev)
	}
	return events, rows.Err()
}
// TailAuditEvents returns the n most recent audit log entries, in
// chronological (oldest-first) order.
//
// A non-positive n returns no entries. SQLite treats a negative LIMIT as "no
// limit", so without this guard a negative n would return the entire audit log.
func (db *DB) TailAuditEvents(n int) ([]*model.AuditEvent, error) {
	if n <= 0 {
		return nil, nil
	}

	// Fetch the newest n rows in descending order, then flip them below.
	rows, err := db.sql.Query(`
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
FROM audit_log
ORDER BY event_time DESC, id DESC
LIMIT ?
`, n)
	if err != nil {
		return nil, fmt.Errorf("db: tail audit events: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var events []*model.AuditEvent
	for rows.Next() {
		var ev model.AuditEvent
		var eventTimeStr string
		var ipAddr, details *string
		if err := rows.Scan(
			&ev.ID, &eventTimeStr, &ev.EventType,
			&ev.ActorID, &ev.TargetID,
			&ipAddr, &details,
		); err != nil {
			return nil, fmt.Errorf("db: scan audit event: %w", err)
		}
		var parseErr error
		if ev.EventTime, parseErr = parseTime(eventTimeStr); parseErr != nil {
			return nil, parseErr
		}
		if ipAddr != nil {
			ev.IPAddress = *ipAddr
		}
		if details != nil {
			ev.Details = *details
		}
		events = append(events, &ev)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Reverse in place so the oldest of the n events comes first.
	for i, j := 0, len(events)-1; i < j; i, j = i+1, j-1 {
		events[i], events[j] = events[j], events[i]
	}
	return events, nil
}
// ListAuditEventsPaged returns one page of audit log entries matching p, plus
// the total number of matching rows for pagination. Entries are newest first
// (event_time descending, ties broken by id descending).
//
// Filters:
//   - p.AccountID matches either actor_id or target_id;
//   - p.EventType matches exactly;
//   - p.Since keeps events at or after that instant.
//
// p.Offset rows are skipped, with negative values treated as 0. At most
// p.Limit rows are returned; a non-positive p.Limit returns every row after
// the offset, matching ListAuditEvents. In SQLite, LIMIT 0 would yield an
// empty page, while a negative LIMIT means "no upper bound".
//
// Actor and target usernames are resolved with LEFT JOINs against accounts and
// are "" when unknown. The count and the page are read by two separate
// queries, so under concurrent writes total may not exactly match the rows
// seen.
//
// Security: No credential material is included in audit_log rows per the
// WriteAuditEvent contract; joining account usernames is safe for display.
func (db *DB) ListAuditEventsPaged(p AuditQueryParams) ([]*AuditEventView, int64, error) {
	// Build the WHERE clause and its args once; both queries share them.
	where := " WHERE 1=1"
	var args []interface{}
	if p.AccountID != nil {
		where += ` AND (al.actor_id = ? OR al.target_id = ?)`
		args = append(args, *p.AccountID, *p.AccountID)
	}
	if p.EventType != "" {
		where += ` AND al.event_type = ?`
		args = append(args, p.EventType)
	}
	if p.Since != nil {
		where += ` AND al.event_time >= ?`
		args = append(args, p.Since.UTC().Format(time.RFC3339))
	}

	// Count total matching rows first.
	countQuery := `SELECT COUNT(*) FROM audit_log al` + where
	var total int64
	if err := db.sql.QueryRow(countQuery, args...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("db: count audit events: %w", err)
	}

	limit, offset := p.Limit, p.Offset
	if limit <= 0 {
		limit = -1 // SQLite: a negative LIMIT means no upper bound.
	}
	if offset < 0 {
		offset = 0
	}
	// Build the page args in a fresh slice so they never alias args.
	pageArgs := make([]interface{}, 0, len(args)+2)
	pageArgs = append(pageArgs, args...)
	pageArgs = append(pageArgs, limit, offset)

	// Fetch the page with username resolution via LEFT JOIN.
	query := `
SELECT al.id, al.event_time, al.event_type,
al.actor_id, al.target_id,
al.ip_address, al.details,
COALESCE(a1.username, ''), COALESCE(a2.username, '')
FROM audit_log al
LEFT JOIN accounts a1 ON al.actor_id = a1.id
LEFT JOIN accounts a2 ON al.target_id = a2.id` + where + `
ORDER BY al.event_time DESC, al.id DESC`
	query += ` LIMIT ? OFFSET ?`

	rows, err := db.sql.Query(query, pageArgs...)
	if err != nil {
		return nil, 0, fmt.Errorf("db: list audit events paged: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var events []*AuditEventView
	for rows.Next() {
		var ev AuditEventView
		var eventTimeStr string
		var ipAddr, details *string
		if err := rows.Scan(
			&ev.ID, &eventTimeStr, &ev.EventType,
			&ev.ActorID, &ev.TargetID,
			&ipAddr, &details,
			&ev.ActorUsername, &ev.TargetUsername,
		); err != nil {
			return nil, 0, fmt.Errorf("db: scan audit event view: %w", err)
		}
		if ev.EventTime, err = parseTime(eventTimeStr); err != nil {
			return nil, 0, err
		}
		if ipAddr != nil {
			ev.IPAddress = *ipAddr
		}
		if details != nil {
			ev.Details = *details
		}
		events = append(events, &ev)
	}
	if err := rows.Err(); err != nil {
		return nil, 0, err
	}
	return events, total, nil
}
// GetAuditEventByID loads a single audit event by its integer primary key.
// Actor and target usernames are resolved via LEFT JOIN and are "" when
// unknown. It returns ErrNotFound when no row has that id.
func (db *DB) GetAuditEventByID(id int64) (*AuditEventView, error) {
	var (
		ev           AuditEventView
		eventTimeRaw string
		ip, detail   *string
	)
	err := db.sql.QueryRow(`
SELECT al.id, al.event_time, al.event_type,
al.actor_id, al.target_id,
al.ip_address, al.details,
COALESCE(a1.username, ''), COALESCE(a2.username, '')
FROM audit_log al
LEFT JOIN accounts a1 ON al.actor_id = a1.id
LEFT JOIN accounts a2 ON al.target_id = a2.id
WHERE al.id = ?
`, id).Scan(
		&ev.ID, &eventTimeRaw, &ev.EventType,
		&ev.ActorID, &ev.TargetID,
		&ip, &detail,
		&ev.ActorUsername, &ev.TargetUsername,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get audit event %d: %w", id, err)
	}

	if ev.EventTime, err = parseTime(eventTimeRaw); err != nil {
		return nil, err
	}
	if ip != nil {
		ev.IPAddress = *ip
	}
	if detail != nil {
		ev.Details = *detail
	}
	return &ev, nil
}
// SetSystemToken records jti as the active service token of a system
// account, replacing any previous one. On replacement, jti, expires_at and
// created_at are all overwritten, so created_at reflects the latest call.
// expiresAt is stored as a UTC time.RFC3339 string.
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
	expires := expiresAt.UTC().Format(time.RFC3339)
	if _, err := db.sql.Exec(`
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
jti = excluded.jti,
expires_at = excluded.expires_at,
created_at = excluded.created_at
`, accountID, jti, expires, now()); err != nil {
		return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
	}
	return nil
}
// GetSystemToken loads the active service token record of a system account.
// It returns ErrNotFound when none has been set.
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
	var (
		st                     model.SystemToken
		expiresRaw, createdRaw string
	)
	err := db.sql.QueryRow(`
SELECT id, account_id, jti, expires_at, created_at
FROM system_tokens WHERE account_id = ?
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresRaw, &createdRaw)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get system token: %w", err)
	}

	if st.ExpiresAt, err = parseTime(expiresRaw); err != nil {
		return nil, err
	}
	if st.CreatedAt, err = parseTime(createdRaw); err != nil {
		return nil, err
	}
	return &st, nil
}
// Lockout parameters for per-account brute-force protection (Security F-08).
//
// They are package-level variables rather than constants so that tests can
// shrink them to millisecond scale without recompiling.
var (
	// LockoutWindow is the length of the rolling window in which failed
	// logins are counted. Failures from an older window are stale: the next
	// failure restarts the window and the count.
	LockoutWindow = 15 * time.Minute

	// LockoutThreshold is the number of failures within one window that
	// locks the account.
	LockoutThreshold = 10

	// LockoutDuration bounds how long a lockout lasts. IsLockedOut measures
	// it from the start of the failure window, not from the failure that
	// reached the threshold.
	LockoutDuration = 15 * time.Minute
)
// IsLockedOut reports whether accountID is currently locked out after
// repeated failed logins (Security F-08).
//
// It reads the account's failed_logins row and returns true only when all of
// the following hold:
//   - attempt_count >= LockoutThreshold;
//   - no more than LockoutWindow has elapsed since window_start; and
//   - the current time is before window_start + LockoutDuration.
//
// The lock therefore ends at window_start plus the smaller of LockoutWindow
// and LockoutDuration. A LockoutDuration longer than LockoutWindow has no
// additional effect. An account with no failed_logins row is not locked out.
// Errors are returned when the query fails or window_start cannot be parsed.
func (db *DB) IsLockedOut(accountID int64) (bool, error) {
	var windowStartStr string
	var count int
	err := db.sql.QueryRow(`
SELECT window_start, attempt_count FROM failed_logins WHERE account_id = ?
`, accountID).Scan(&windowStartStr, &count)
	if errors.Is(err, sql.ErrNoRows) {
		// No failures recorded, or they were removed by ClearLoginFailures.
		return false, nil
	}
	if err != nil {
		return false, fmt.Errorf("db: is locked out %d: %w", accountID, err)
	}
	windowStart, err := parseTime(windowStartStr)
	if err != nil {
		return false, err
	}
	// Under threshold — not locked out.
	if count < LockoutThreshold {
		return false, nil
	}
	// Threshold reached (count >= LockoutThreshold); locked out until
	// window_start + LockoutDuration.
	// Security (F-08): the lockout clock starts from window_start (when the
	// first failure in this window occurred), not from the last failure.
	// If the rolling window itself has expired the failures are stale and
	// cannot trigger a lockout regardless of count.
	if time.Since(windowStart) > LockoutWindow {
		return false, nil
	}
	lockedUntil := windowStart.Add(LockoutDuration)
	return time.Now().Before(lockedUntil), nil
}
// RecordLoginFailure counts one failed login for accountID.
//
// A single UPSERT either creates the failed_logins row (count 1, window
// starting now) or updates the existing one. If the existing window started
// before now - LockoutWindow, it is stale: the window restarts at now and the
// count resets to 1. Otherwise the count is incremented and the window start
// is kept. Both CASE expressions see the pre-update window_start, because
// SQLite evaluates all SET expressions before assigning any of them.
//
// NOTE(review): window_start (written via now()) is compared as text with a
// UTC time.RFC3339 cutoff. This is correct only if now() uses the same layout.
// Confirm.
func (db *DB) RecordLoginFailure(accountID int64) error {
	stamp := now()
	cutoff := time.Now().Add(-LockoutWindow).UTC().Format(time.RFC3339)
	if _, err := db.sql.Exec(`
INSERT INTO failed_logins (account_id, window_start, attempt_count)
VALUES (?, ?, 1)
ON CONFLICT(account_id) DO UPDATE SET
window_start = CASE WHEN window_start < ? THEN excluded.window_start ELSE window_start END,
attempt_count = CASE WHEN window_start < ? THEN 1 ELSE attempt_count + 1 END
`, accountID, stamp, cutoff, cutoff); err != nil {
		return fmt.Errorf("db: record login failure for account %d: %w", accountID, err)
	}
	return nil
}
// ClearLoginFailures deletes the failed_logins row for accountID, resetting
// its lockout state. It is meant to be called after a successful login.
// Clearing an account with no recorded failures is not an error.
func (db *DB) ClearLoginFailures(accountID int64) error {
	if _, err := db.sql.Exec(`DELETE FROM failed_logins WHERE account_id = ?`, accountID); err != nil {
		return fmt.Errorf("db: clear login failures for account %d: %w", accountID, err)
	}
	return nil
}