Files
mcias/internal/db/accounts.go
Kyle Isom bf9002a31c Fix F-03: make token renewal atomic
- db/accounts.go: add RenewToken(oldJTI, reason, newJTI,
  accountID, issuedAt, expiresAt) which wraps RevokeToken +
  TrackToken in a single BEGIN/COMMIT transaction; if either
  step fails the whole tx rolls back, so the user is never
  left with neither old nor new token valid
- server.go (handleRenewToken): replace separate RevokeToken +
  TrackToken calls with single RenewToken call; failure now
  returns 500 instead of silently losing revocation
- grpcserver/auth.go (RenewToken): same replacement
- db/db_test.go: TestRenewTokenAtomic verifies old token is
  revoked with correct reason, new token is tracked and not
  revoked, and a second renewal on the already-revoked old
  token returns an error
- AUDIT.md: mark F-03 as fixed
Security: without atomicity a crash/error between revoke and
  track could leave the old token active alongside the new one
  (two live tokens) or revoke the old token without tracking
  the new one (user locked out). The transaction ensures
  exactly one of the two tokens is valid at all times.
2026-03-11 20:24:32 -07:00

1001 lines
30 KiB
Go

package db
import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
"github.com/google/uuid"
)
// CreateAccount inserts a new account in the 'active' state and returns it.
//
// The external UUID is generated here with uuid.New; the integer ID is the
// row id reported by LastInsertId. An empty passwordHash is written as SQL
// NULL. CreatedAt and UpdatedAt are both set from the single timestamp used
// for the insert.
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
	accountUUID := uuid.New().String()
	stamp := now()

	res, execErr := db.sql.Exec(`
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
VALUES (?, ?, ?, ?, 'active', ?, ?)
`, accountUUID, username, string(accountType), nullString(passwordHash), stamp, stamp)
	if execErr != nil {
		return nil, fmt.Errorf("db: create account %q: %w", username, execErr)
	}

	rowID, idErr := res.LastInsertId()
	if idErr != nil {
		return nil, fmt.Errorf("db: last insert id for account %q: %w", username, idErr)
	}

	// Parse the stored timestamp string the same way reads do, so the
	// returned struct carries the value a later lookup would see.
	stampTime, parseErr := parseTime(stamp)
	if parseErr != nil {
		return nil, parseErr
	}

	acct := model.Account{
		ID:           rowID,
		UUID:         accountUUID,
		Username:     username,
		AccountType:  accountType,
		Status:       model.AccountStatusActive,
		PasswordHash: passwordHash,
		CreatedAt:    stampTime,
		UpdatedAt:    stampTime,
	}
	return &acct, nil
}
// GetAccountByUUID fetches the account whose external UUID equals accountUUID.
//
// No status filter is applied, so an account whose status is 'deleted' is
// still returned. ErrNotFound is returned when no row matches.
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE uuid = ?
`, accountUUID)
	return db.scanAccount(row)
}
// GetAccountByUsername fetches the account with the given username.
//
// The query compares with plain `=`; case-insensitive matching therefore
// depends on the username column's collation in the schema
// (NOTE(review): confirm the column is declared COLLATE NOCASE).
// No status filter is applied. ErrNotFound is returned when no row matches.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE username = ?
`, username)
	return db.scanAccount(row)
}
// ListAccounts returns every account whose status is not 'deleted', sorted by
// username ascending. When no account qualifies, the result is a nil slice
// and a nil error.
func (db *DB) ListAccounts() ([]*model.Account, error) {
	rows, queryErr := db.sql.Query(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts
WHERE status != 'deleted'
ORDER BY username ASC
`)
	if queryErr != nil {
		return nil, fmt.Errorf("db: list accounts: %w", queryErr)
	}
	// The Close error is intentionally dropped; iteration failures are
	// reported through rows.Err below.
	defer func() { _ = rows.Close() }()

	var out []*model.Account
	for rows.Next() {
		acct, scanErr := db.scanAccountRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, acct)
	}
	// Any iteration error is returned together with the rows read so far.
	return out, rows.Err()
}
// UpdateAccountStatus sets an account's status and refreshes updated_at.
//
// deleted_at receives the same timestamp when the new status is
// AccountStatusDeleted. For every other status it is set to NULL, so moving
// an account out of the deleted state also clears its deletion time. An
// accountID that matches no row is not reported as an error.
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
	stamp := now()

	var deletedAt *string // stays nil (NULL) unless the account is being deleted
	if status == model.AccountStatusDeleted {
		deletedAt = &stamp
	}

	const stmt = `
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
WHERE id = ?
`
	if _, err := db.sql.Exec(stmt, string(status), deletedAt, stamp, accountID); err != nil {
		return fmt.Errorf("db: update account status: %w", err)
	}
	return nil
}
// UpdatePasswordHash stores a new Argon2id password hash for the account and
// refreshes updated_at.
//
// The hash is written verbatim: unlike CreateAccount, an empty string is not
// converted to NULL. An accountID that matches no row is not an error.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
	const stmt = `
UPDATE accounts SET password_hash = ?, updated_at = ?
WHERE id = ?
`
	if _, err := db.sql.Exec(stmt, hash, now(), accountID); err != nil {
		return fmt.Errorf("db: update password hash: %w", err)
	}
	return nil
}
// StorePendingTOTP saves the encrypted TOTP secret and nonce for an account
// while leaving totp_required untouched.
//
// It is intended for the enrollment step (POST /v1/auth/totp/enroll): the
// secret must exist so the user's first code can be confirmed, but TOTP
// should not yet be enforced at login. Only after confirmation
// (handleTOTPConfirm) is SetTOTP called to flip totp_required to 1.
//
// Security: leaving totp_required at 0 during enrollment means a user who
// abandons the flow after the secret is generated, but before configuring an
// authenticator app, is not locked out.
func (db *DB) StorePendingTOTP(accountID int64, secretEnc, secretNonce []byte) error {
	const stmt = `
UPDATE accounts
SET totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`
	if _, err := db.sql.Exec(stmt, secretEnc, secretNonce, now(), accountID); err != nil {
		return fmt.Errorf("db: store pending TOTP: %w", err)
	}
	return nil
}
// SetTOTP stores the encrypted TOTP secret and nonce and sets totp_required
// to 1 in the same UPDATE.
//
// Call it only after the user has confirmed a TOTP code. The initial
// enrollment step should use StorePendingTOTP, which does not enforce TOTP.
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
	const stmt = `
UPDATE accounts
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`
	if _, err := db.sql.Exec(stmt, secretEnc, secretNonce, now(), accountID); err != nil {
		return fmt.Errorf("db: set TOTP: %w", err)
	}
	return nil
}
// ClearTOTP disables TOTP for an account: totp_required is reset to 0 and
// both the encrypted secret and its nonce are set to NULL.
func (db *DB) ClearTOTP(accountID int64) error {
	const stmt = `
UPDATE accounts
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
WHERE id = ?
`
	if _, err := db.sql.Exec(stmt, now(), accountID); err != nil {
		return fmt.Errorf("db: clear TOTP: %w", err)
	}
	return nil
}
// scanAccount decodes one account from a single-row query result.
//
// The row must select the columns in the order used by GetAccountByUUID and
// GetAccountByUsername. sql.ErrNoRows is translated to ErrNotFound; any other
// scan failure is wrapped. Type conversion and timestamp parsing are delegated
// to finishAccountScan.
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
	var acct model.Account
	var kind, state string
	var totpFlag int
	var createdRaw, updatedRaw string
	var deletedRaw *string
	var secretEnc, secretNonce []byte

	scanErr := row.Scan(
		&acct.ID, &acct.UUID, &acct.Username,
		&kind, &acct.PasswordHash,
		&state, &totpFlag,
		&secretEnc, &secretNonce,
		&createdRaw, &updatedRaw, &deletedRaw,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrNotFound
	case scanErr != nil:
		return nil, fmt.Errorf("db: scan account: %w", scanErr)
	}
	return finishAccountScan(&acct, kind, state, totpFlag, secretEnc, secretNonce, createdRaw, updatedRaw, deletedRaw)
}
// scanAccountRow decodes the account at the current position of rows. The
// caller must already have advanced the cursor with rows.Next.
//
// The column order matches scanAccount. Scan failures are wrapped. No
// ErrNotFound translation is done because iteration only visits existing rows.
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
	var acct model.Account
	var kind, state string
	var totpFlag int
	var createdRaw, updatedRaw string
	var deletedRaw *string
	var secretEnc, secretNonce []byte

	if scanErr := rows.Scan(
		&acct.ID, &acct.UUID, &acct.Username,
		&kind, &acct.PasswordHash,
		&state, &totpFlag,
		&secretEnc, &secretNonce,
		&createdRaw, &updatedRaw, &deletedRaw,
	); scanErr != nil {
		return nil, fmt.Errorf("db: scan account row: %w", scanErr)
	}
	return finishAccountScan(&acct, kind, state, totpFlag, secretEnc, secretNonce, createdRaw, updatedRaw, deletedRaw)
}
// finishAccountScan copies the raw scanned column values into a and returns
// it. The work is shared by scanAccount and scanAccountRow.
//
// The string columns become the typed AccountType/AccountStatus values, and
// totp_required is true only when it equals 1. The created/updated/deleted
// timestamp strings are parsed, with deleted_at allowed to be NULL. The first
// parse error is returned as-is, with a nil account.
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
	a.AccountType, a.Status = model.AccountType(accountType), model.AccountStatus(status)
	a.TOTPRequired = totpRequired == 1
	a.TOTPSecretEnc, a.TOTPSecretNonce = totpSecretEnc, totpSecretNonce

	var err error
	if a.CreatedAt, err = parseTime(createdAtStr); err != nil {
		return nil, err
	}
	if a.UpdatedAt, err = parseTime(updatedAtStr); err != nil {
		return nil, err
	}
	if a.DeletedAt, err = nullableTime(deletedAtStr); err != nil {
		return nil, err
	}
	return a, nil
}
// nullString maps the empty string to a nil pointer, which database/sql binds
// as SQL NULL. Any other value, including whitespace-only strings, is
// returned as a pointer to a copy of s.
func nullString(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
// GetRoles lists the roles granted to accountID, ordered by role ascending.
// An account with no roles yields a nil slice and a nil error.
func (db *DB) GetRoles(accountID int64) ([]string, error) {
	rows, err := db.sql.Query(`
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
`, accountID)
	if err != nil {
		return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
	}
	defer func() { _ = rows.Close() }()

	var granted []string
	for rows.Next() {
		var r string
		if scanErr := rows.Scan(&r); scanErr != nil {
			return nil, fmt.Errorf("db: scan role: %w", scanErr)
		}
		granted = append(granted, r)
	}
	return granted, rows.Err()
}
// GrantRole gives role to accountID and records who granted it and when.
// grantedBy may be nil.
//
// The statement is INSERT OR IGNORE, so granting a role the account already
// holds is a no-op rather than an error.
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
	const stmt = `
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`
	if _, err := db.sql.Exec(stmt, accountID, role, grantedBy, now()); err != nil {
		return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
	}
	return nil
}
// RevokeRole deletes role from accountID's role set. Removing a role the
// account does not hold deletes nothing and is not an error.
func (db *DB) RevokeRole(accountID int64, role string) error {
	const stmt = `
DELETE FROM account_roles WHERE account_id = ? AND role = ?
`
	if _, err := db.sql.Exec(stmt, accountID, role); err != nil {
		return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
	}
	return nil
}
// SetRoles replaces the full role set for an account atomically.
//
// All existing roles for accountID are deleted and the given roles are
// inserted inside a single transaction. If any step fails, the transaction is
// rolled back and the previous role set is left untouched. Every inserted row
// carries the same granted_at timestamp and grantedBy (which may be nil).
//
// Duplicate entries in roles are inserted only once, so a repeated role cannot
// abort the whole update with a uniqueness conflict. GrantRole's
// INSERT OR IGNORE suggests (account_id, role) is unique.
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
	tx, err := db.sql.Begin()
	if err != nil {
		return fmt.Errorf("db: set roles begin tx: %w", err)
	}
	// Rollback after a successful Commit returns sql.ErrTxDone, which is
	// ignored, so deferring it safely covers every early-return path.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
		return fmt.Errorf("db: set roles delete existing: %w", err)
	}

	n := now()
	seen := make(map[string]struct{}, len(roles))
	for _, role := range roles {
		if _, dup := seen[role]; dup {
			continue
		}
		seen[role] = struct{}{}
		if _, err := tx.Exec(`
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, n); err != nil {
			return fmt.Errorf("db: set roles insert %q: %w", role, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("db: set roles commit: %w", err)
	}
	return nil
}
// HasRole reports whether accountID currently holds role. It counts the
// matching account_roles rows and reports true for any non-zero count.
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
	row := db.sql.QueryRow(`
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role)

	var matches int
	if err := row.Scan(&matches); err != nil {
		return false, fmt.Errorf("db: has role: %w", err)
	}
	return matches > 0, nil
}
// WriteServerConfig upserts the encrypted Ed25519 signing key and its nonce
// into the singleton server_config row (id = 1).
//
// If the row already exists, only signing_key_enc, signing_key_nonce and
// updated_at are overwritten. created_at and every other column (for example
// master_key_salt) keep their values.
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
	const stmt = `
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
VALUES (1, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
signing_key_enc = excluded.signing_key_enc,
signing_key_nonce = excluded.signing_key_nonce,
updated_at = excluded.updated_at
`
	stamp := now()
	if _, err := db.sql.Exec(stmt, signingKeyEnc, signingKeyNonce, stamp, stamp); err != nil {
		return fmt.Errorf("db: write server config: %w", err)
	}
	return nil
}
// ReadServerConfig returns the encrypted Ed25519 signing key and its nonce
// from the singleton server_config row (id = 1).
//
// ErrNotFound is returned when no row exists yet. It is also returned when
// the row exists but holds no signing key at all. That happens when
// WriteMasterKeySalt created row id=1 first: that insert sets only
// master_key_salt and the timestamps, so the signing-key columns are
// presumably NULL (NOTE(review): confirm the schema has no non-NULL default).
// NULL scans into a nil slice, which callers could otherwise mistake for a
// valid but empty key; this mirrors ReadMasterKeySalt's nil check.
//
// A row with only one of the two columns set is reported as an error, not
// ErrNotFound, so a corrupted config is never treated as "first run".
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
	err = db.sql.QueryRow(`
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
`).Scan(&signingKeyEnc, &signingKeyNonce)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil, ErrNotFound
	}
	if err != nil {
		return nil, nil, fmt.Errorf("db: read server config: %w", err)
	}
	switch {
	case signingKeyEnc == nil && signingKeyNonce == nil:
		return nil, nil, ErrNotFound
	case signingKeyEnc == nil || signingKeyNonce == nil:
		return nil, nil, errors.New("db: read server config: incomplete signing key (ciphertext or nonce missing)")
	}
	return signingKeyEnc, signingKeyNonce, nil
}
// WriteMasterKeySalt upserts the Argon2id KDF salt used to derive the master
// key into the singleton server_config row (id = 1).
//
// The salt must stay the same across restarts so that a given passphrase
// always derives the same master key. If the row already exists, only
// master_key_salt and updated_at change; the signing-key columns and
// created_at are left intact.
func (db *DB) WriteMasterKeySalt(salt []byte) error {
	const stmt = `
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
VALUES (1, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
master_key_salt = excluded.master_key_salt,
updated_at = excluded.updated_at
`
	stamp := now()
	if _, err := db.sql.Exec(stmt, salt, stamp, stamp); err != nil {
		return fmt.Errorf("db: write master key salt: %w", err)
	}
	return nil
}
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
//
// ErrNotFound signals that no salt has been stored yet (first run). That
// covers two cases: the server_config row does not exist, or it exists with a
// NULL master_key_salt (for example, when WriteServerConfig created the row
// first).
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
	var salt []byte
	err := db.sql.QueryRow(`
SELECT master_key_salt FROM server_config WHERE id = 1
`).Scan(&salt)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: read master key salt: %w", err)
	case salt == nil:
		return nil, ErrNotFound
	default:
		return salt, nil
	}
}
// WritePGCredentials stores or replaces the Postgres connection details for
// accountID.
//
// passwordEnc and passwordNonce are written exactly as given; no encryption
// happens here. When a row already exists for the account, every credential
// column and updated_at are overwritten in place and created_at is kept.
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
	const stmt = `
INSERT INTO pg_credentials
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
pg_host = excluded.pg_host,
pg_port = excluded.pg_port,
pg_database = excluded.pg_database,
pg_username = excluded.pg_username,
pg_password_enc = excluded.pg_password_enc,
pg_password_nonce = excluded.pg_password_nonce,
updated_at = excluded.updated_at
`
	stamp := now()
	if _, err := db.sql.Exec(stmt,
		accountID, host, port, dbName, username,
		passwordEnc, passwordNonce,
		stamp, stamp,
	); err != nil {
		return fmt.Errorf("db: write pg credentials: %w", err)
	}
	return nil
}
// ReadPGCredentials loads the stored (still encrypted) Postgres credentials
// for accountID. ErrNotFound is returned when none are stored; timestamp
// parse failures are returned unwrapped.
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
	var cred model.PGCredential
	var createdRaw, updatedRaw string

	row := db.sql.QueryRow(`
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
pg_password_enc, pg_password_nonce, created_at, updated_at
FROM pg_credentials WHERE account_id = ?
`, accountID)
	err := row.Scan(
		&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
		&cred.PGDatabase, &cred.PGUsername,
		&cred.PGPasswordEnc, &cred.PGPasswordNonce,
		&createdRaw, &updatedRaw,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: read pg credentials: %w", err)
	}

	if cred.CreatedAt, err = parseTime(createdRaw); err != nil {
		return nil, err
	}
	if cred.UpdatedAt, err = parseTime(updatedRaw); err != nil {
		return nil, err
	}
	return &cred, nil
}
// WriteAuditEvent appends one row to audit_log.
//
// actorID and targetID may be nil. An empty ipAddress or details string is
// stored as NULL. event_time is not supplied by this insert
// (NOTE(review): presumably a column default in the schema — confirm).
//
// Security: details must never contain credential material.
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
	const stmt = `
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
VALUES (?, ?, ?, ?, ?)
`
	ip, detail := nullString(ipAddress), nullString(details)
	if _, err := db.sql.Exec(stmt, eventType, actorID, targetID, ip, detail); err != nil {
		return fmt.Errorf("db: write audit event %q: %w", eventType, err)
	}
	return nil
}
// TrackToken records a newly issued JWT's JTI so it can later be looked up
// and revoked.
//
// issuedAt and expiresAt are converted to UTC and stored as RFC 3339 strings
// (second precision). The row starts unrevoked: revoked_at and revoke_reason
// are not set by this insert.
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
	issued := issuedAt.UTC().Format(time.RFC3339)
	expires := expiresAt.UTC().Format(time.RFC3339)

	if _, err := db.sql.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, jti, accountID, issued, expires); err != nil {
		return fmt.Errorf("db: track token %q: %w", jti, err)
	}
	return nil
}
// GetTokenRecord loads the token_revocation row for jti.
//
// ErrNotFound means no row exists: the token was never tracked by this
// server, or its row has since been deleted (e.g. by PruneExpiredTokens).
// RevokeReason is left empty when revoke_reason is NULL. Timestamp parse
// failures are returned unwrapped.
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
	var rec model.TokenRecord
	var expiresRaw, issuedRaw, createdRaw string
	var revokedRaw, reason *string

	row := db.sql.QueryRow(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation WHERE jti = ?
`, jti)
	scanErr := row.Scan(
		&rec.ID, &rec.JTI, &rec.AccountID,
		&expiresRaw, &issuedRaw, &revokedRaw, &reason,
		&createdRaw,
	)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, ErrNotFound
	case scanErr != nil:
		return nil, fmt.Errorf("db: get token record %q: %w", jti, scanErr)
	}

	var err error
	if rec.ExpiresAt, err = parseTime(expiresRaw); err != nil {
		return nil, err
	}
	if rec.IssuedAt, err = parseTime(issuedRaw); err != nil {
		return nil, err
	}
	if rec.CreatedAt, err = parseTime(createdRaw); err != nil {
		return nil, err
	}
	if rec.RevokedAt, err = nullableTime(revokedRaw); err != nil {
		return nil, err
	}
	if reason != nil {
		rec.RevokeReason = *reason
	}
	return &rec, nil
}
// RevokeToken marks the token identified by jti as revoked now, with an
// optional reason (an empty reason is stored as NULL).
//
// Only a row whose revoked_at is still NULL is updated. If jti is unknown or
// already revoked, an error is returned; it does not wrap ErrNotFound.
func (db *DB) RevokeToken(jti, reason string) error {
	stamp := now()
	res, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, stamp, nullString(reason), jti)
	if err != nil {
		return fmt.Errorf("db: revoke token %q: %w", jti, err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("db: revoke token rows affected: %w", err)
	case affected == 0:
		return fmt.Errorf("db: token %q not found or already revoked", jti)
	default:
		return nil
	}
}
// RenewToken atomically revokes the old token and tracks the new one in a
// single SQLite transaction.
//
// The revocation only matches a row for oldJTI whose revoked_at is still
// NULL. If no such row exists (unknown JTI or already revoked), an error is
// returned and nothing is written. The new token is inserted for accountID,
// with issuedAt/expiresAt stored as UTC RFC 3339 strings in the same layout
// TrackToken uses.
//
// Security: the two operations must be atomic so that a failure between them
// cannot leave the user with neither a valid old token nor a new one, or with
// both live. Every early return triggers the deferred Rollback, which also
// undoes the revocation. Rollback after a successful Commit returns
// sql.ErrTxDone, which is ignored.
//
// Note: db.sql.Begin issues the driver's default BEGIN. For the common Go
// SQLite drivers that is a DEFERRED transaction, not BEGIN IMMEDIATE, unless
// the DSN selects otherwise (e.g. _txlock=immediate). The write lock is
// therefore acquired by the first UPDATE, not at BEGIN. Serialisation against
// other writers relies on SQLite's single-writer lock and on the
// connection-pool settings made elsewhere (NOTE(review): presumably
// MaxOpenConns(1) — confirm in the DB setup).
func (db *DB) RenewToken(oldJTI, reason, newJTI string, accountID int64, issuedAt, expiresAt time.Time) error {
	tx, err := db.sql.Begin()
	if err != nil {
		return fmt.Errorf("db: renew token begin tx: %w", err)
	}
	defer func() { _ = tx.Rollback() }()

	n := now()

	// Revoke the old token; it must exist and not already be revoked.
	result, err := tx.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, n, nullString(reason), oldJTI)
	if err != nil {
		return fmt.Errorf("db: renew token revoke old %q: %w", oldJTI, err)
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("db: renew token revoke rows affected: %w", err)
	}
	if rows == 0 {
		return fmt.Errorf("db: renew token: old token %q not found or already revoked", oldJTI)
	}

	// Track the new token (same layout as TrackToken).
	_, err = tx.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, newJTI, accountID,
		issuedAt.UTC().Format(time.RFC3339),
		expiresAt.UTC().Format(time.RFC3339))
	if err != nil {
		return fmt.Errorf("db: renew token track new %q: %w", newJTI, err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("db: renew token commit: %w", err)
	}
	return nil
}
// RevokeAllUserTokens revokes, in a single UPDATE, every token of accountID
// that is neither already revoked nor expired. All affected rows get the same
// revoked_at and reason (an empty reason is stored as NULL). Matching zero
// tokens is not an error.
//
// NOTE(review): expires_at is written by TrackToken/RenewToken as a UTC
// RFC 3339 string and compared here as a string against now(). The
// comparison is only chronological if now() uses the same layout — confirm.
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
	const stmt = `
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
`
	stamp := now()
	if _, err := db.sql.Exec(stmt, stamp, nullString(reason), accountID, stamp); err != nil {
		return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
	}
	return nil
}
// PruneExpiredTokens deletes token_revocation rows whose expires_at is before
// now(), whether or not they were revoked, and reports how many rows were
// removed. GetTokenRecord returns ErrNotFound for a pruned JTI.
//
// NOTE(review): the cutoff comparison is a string comparison between the
// RFC 3339 expires_at written by TrackToken/RenewToken and now(); confirm
// now() uses the same layout.
func (db *DB) PruneExpiredTokens() (int64, error) {
	cutoff := now()
	res, err := db.sql.Exec(`
DELETE FROM token_revocation WHERE expires_at < ?
`, cutoff)
	if err != nil {
		return 0, fmt.Errorf("db: prune expired tokens: %w", err)
	}
	deleted, err := res.RowsAffected()
	return deleted, err
}
// ListTokensForAccount returns every token_revocation row for accountID,
// newest first (issued_at descending). Revoked and expired rows are included.
// An account with no tracked tokens yields a nil slice and a nil error.
func (db *DB) ListTokensForAccount(accountID int64) ([]*model.TokenRecord, error) {
	rows, err := db.sql.Query(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation
WHERE account_id = ?
ORDER BY issued_at DESC
`, accountID)
	if err != nil {
		return nil, fmt.Errorf("db: list tokens for account %d: %w", accountID, err)
	}
	defer func() { _ = rows.Close() }()

	// scanOne decodes the current row, parsing the timestamp strings and
	// unwrapping the nullable revocation columns.
	scanOne := func() (*model.TokenRecord, error) {
		var rec model.TokenRecord
		var expiresRaw, issuedRaw, createdRaw string
		var revokedRaw, reason *string
		if scanErr := rows.Scan(
			&rec.ID, &rec.JTI, &rec.AccountID,
			&expiresRaw, &issuedRaw, &revokedRaw, &reason,
			&createdRaw,
		); scanErr != nil {
			return nil, fmt.Errorf("db: scan token record: %w", scanErr)
		}
		var perr error
		if rec.ExpiresAt, perr = parseTime(expiresRaw); perr != nil {
			return nil, perr
		}
		if rec.IssuedAt, perr = parseTime(issuedRaw); perr != nil {
			return nil, perr
		}
		if rec.CreatedAt, perr = parseTime(createdRaw); perr != nil {
			return nil, perr
		}
		if rec.RevokedAt, perr = nullableTime(revokedRaw); perr != nil {
			return nil, perr
		}
		if reason != nil {
			rec.RevokeReason = *reason
		}
		return &rec, nil
	}

	var records []*model.TokenRecord
	for rows.Next() {
		rec, recErr := scanOne()
		if recErr != nil {
			return nil, recErr
		}
		records = append(records, rec)
	}
	return records, rows.Err()
}
// AuditQueryParams filters for ListAuditEvents and ListAuditEventsPaged.
// Filter fields left at their zero value are not applied. Limit and Offset
// are paging values; see each List function for how it treats non-positive
// values.
type AuditQueryParams struct {
	// AccountID, when non-nil, matches events whose actor_id OR target_id
	// equals this account.
	AccountID *int64

	// Since, when non-nil, keeps events whose event_time is at or after this
	// instant. It is formatted as UTC RFC 3339 for the comparison.
	Since *time.Time

	// EventType, when non-empty, requires an exact event_type match.
	EventType string

	// Limit caps the number of rows returned.
	Limit int

	// Offset is the number of matching rows to skip.
	Offset int
}
// AuditEventView extends AuditEvent with resolved actor/target usernames for display.
// Usernames are resolved via a LEFT JOIN and are empty if the actor/target is unknown.
// The fieldalignment hint is suppressed: the embedded model.AuditEvent layout is fixed
// and changing to explicit fields would break JSON serialisation.
// (Presumably because encoding/json promotes the untagged embedded struct's
// fields into the top-level object — NOTE(review): confirm model.AuditEvent's tags.)
type AuditEventView struct { //nolint:govet
	model.AuditEvent

	// ActorUsername is accounts.username for actor_id, or "" when actor_id is
	// NULL or matches no account (COALESCE in ListAuditEventsPaged and
	// GetAuditEventByID).
	ActorUsername string `json:"actor_username,omitempty"`

	// TargetUsername is the same lookup for target_id.
	TargetUsername string `json:"target_username,omitempty"`
}
// ListAuditEvents returns audit log entries matching p, ordered by event_time
// ascending (ties broken by id ascending).
//
// Filters (all optional, combined with AND):
//   - p.AccountID matches events where the account is the actor OR the target.
//   - p.EventType requires an exact event_type match.
//   - p.Since keeps events at or after that instant. It is formatted as UTC
//     RFC 3339 and compared against the stored event_time.
//
// Paging: when p.Limit > 0, at most p.Limit rows are returned. When
// p.Offset > 0, that many matching rows are skipped first, so the Offset
// field of AuditQueryParams is honoured here too.
func (db *DB) ListAuditEvents(p AuditQueryParams) ([]*model.AuditEvent, error) {
	query := `
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
FROM audit_log
WHERE 1=1
`
	args := []interface{}{}
	if p.AccountID != nil {
		query += ` AND (actor_id = ? OR target_id = ?)`
		args = append(args, *p.AccountID, *p.AccountID)
	}
	if p.EventType != "" {
		query += ` AND event_type = ?`
		args = append(args, p.EventType)
	}
	if p.Since != nil {
		query += ` AND event_time >= ?`
		args = append(args, p.Since.UTC().Format(time.RFC3339))
	}
	query += ` ORDER BY event_time ASC, id ASC`
	if p.Limit > 0 || p.Offset > 0 {
		// SQLite only accepts OFFSET as part of a LIMIT clause. A negative
		// LIMIT means "no upper bound", so an Offset-only request skips rows
		// without capping the result. A negative OFFSET behaves like 0.
		limit := p.Limit
		if limit <= 0 {
			limit = -1
		}
		query += ` LIMIT ? OFFSET ?`
		args = append(args, limit, p.Offset)
	}

	rows, err := db.sql.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("db: list audit events: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var events []*model.AuditEvent
	for rows.Next() {
		var ev model.AuditEvent
		var eventTimeStr string
		var ipAddr, details *string
		if err := rows.Scan(
			&ev.ID, &eventTimeStr, &ev.EventType,
			&ev.ActorID, &ev.TargetID,
			&ipAddr, &details,
		); err != nil {
			return nil, fmt.Errorf("db: scan audit event: %w", err)
		}
		ev.EventTime, err = parseTime(eventTimeStr)
		if err != nil {
			return nil, err
		}
		if ipAddr != nil {
			ev.IPAddress = *ipAddr
		}
		if details != nil {
			ev.Details = *details
		}
		events = append(events, &ev)
	}
	return events, rows.Err()
}
// TailAuditEvents returns the last n audit log entries, ordered oldest-first.
//
// Rows are selected newest-first (event_time DESC, id DESC) with LIMIT n and
// then reversed in memory. A non-positive n returns no entries. SQLite treats
// a negative LIMIT as "no limit", so without the guard a negative n would load
// the entire audit log.
func (db *DB) TailAuditEvents(n int) ([]*model.AuditEvent, error) {
	if n <= 0 {
		return nil, nil
	}

	// Fetch last n by descending order, then reverse for chronological output.
	rows, err := db.sql.Query(`
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
FROM audit_log
ORDER BY event_time DESC, id DESC
LIMIT ?
`, n)
	if err != nil {
		return nil, fmt.Errorf("db: tail audit events: %w", err)
	}
	defer func() { _ = rows.Close() }()

	events := make([]*model.AuditEvent, 0, n)
	for rows.Next() {
		var ev model.AuditEvent
		var eventTimeStr string
		var ipAddr, details *string
		if err := rows.Scan(
			&ev.ID, &eventTimeStr, &ev.EventType,
			&ev.ActorID, &ev.TargetID,
			&ipAddr, &details,
		); err != nil {
			return nil, fmt.Errorf("db: scan audit event: %w", err)
		}
		var parseErr error
		ev.EventTime, parseErr = parseTime(eventTimeStr)
		if parseErr != nil {
			return nil, parseErr
		}
		if ipAddr != nil {
			ev.IPAddress = *ipAddr
		}
		if details != nil {
			ev.Details = *details
		}
		events = append(events, &ev)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Reverse to oldest-first.
	for i, j := 0, len(events)-1; i < j; i, j = i+1, j-1 {
		events[i], events[j] = events[j], events[i]
	}
	if len(events) == 0 {
		// Keep the nil-slice result for "no events".
		return nil, nil
	}
	return events, nil
}
// ListAuditEventsPaged returns one page of audit log entries matching p,
// newest first (event_time DESC, id DESC), together with the total number of
// matching rows for pagination.
//
// Actor and target usernames are resolved with LEFT JOINs on accounts and are
// "" when unknown. p.Limit and p.Offset are always bound to LIMIT/OFFSET, so a
// zero Limit yields an empty page (only the total is meaningful). The count
// and the page are separate queries, not one transaction.
//
// Security: No credential material is included in audit_log rows per the
// WriteAuditEvent contract; joining account usernames is safe for display.
func (db *DB) ListAuditEventsPaged(p AuditQueryParams) ([]*AuditEventView, int64, error) {
	// Filters shared by the COUNT query and the page query.
	where := " WHERE 1=1"
	var filterArgs []interface{}
	if id := p.AccountID; id != nil {
		where += ` AND (al.actor_id = ? OR al.target_id = ?)`
		filterArgs = append(filterArgs, *id, *id)
	}
	if p.EventType != "" {
		where += ` AND al.event_type = ?`
		filterArgs = append(filterArgs, p.EventType)
	}
	if since := p.Since; since != nil {
		where += ` AND al.event_time >= ?`
		filterArgs = append(filterArgs, since.UTC().Format(time.RFC3339))
	}

	var total int64
	countSQL := `SELECT COUNT(*) FROM audit_log al` + where
	if err := db.sql.QueryRow(countSQL, filterArgs...).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("db: count audit events: %w", err)
	}

	pageSQL := `
SELECT al.id, al.event_time, al.event_type,
al.actor_id, al.target_id,
al.ip_address, al.details,
COALESCE(a1.username, ''), COALESCE(a2.username, '')
FROM audit_log al
LEFT JOIN accounts a1 ON al.actor_id = a1.id
LEFT JOIN accounts a2 ON al.target_id = a2.id` + where + `
ORDER BY al.event_time DESC, al.id DESC` + ` LIMIT ? OFFSET ?`

	// A fresh slice guarantees appending the paging values cannot alias the
	// backing array of filterArgs.
	pageArgs := make([]interface{}, 0, len(filterArgs)+2)
	pageArgs = append(pageArgs, filterArgs...)
	pageArgs = append(pageArgs, p.Limit, p.Offset)

	rows, err := db.sql.Query(pageSQL, pageArgs...)
	if err != nil {
		return nil, 0, fmt.Errorf("db: list audit events paged: %w", err)
	}
	defer func() { _ = rows.Close() }()

	var page []*AuditEventView
	for rows.Next() {
		view := new(AuditEventView)
		var when string
		var ip, detail *string
		if scanErr := rows.Scan(
			&view.ID, &when, &view.EventType,
			&view.ActorID, &view.TargetID,
			&ip, &detail,
			&view.ActorUsername, &view.TargetUsername,
		); scanErr != nil {
			return nil, 0, fmt.Errorf("db: scan audit event view: %w", scanErr)
		}
		ts, parseErr := parseTime(when)
		if parseErr != nil {
			return nil, 0, parseErr
		}
		view.EventTime = ts
		if ip != nil {
			view.IPAddress = *ip
		}
		if detail != nil {
			view.Details = *detail
		}
		page = append(page, view)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, 0, iterErr
	}
	return page, total, nil
}
// GetAuditEventByID fetches the audit event with primary key id. Actor and
// target usernames are resolved via LEFT JOINs on accounts ("" when unknown).
// ErrNotFound is returned when no row has that id.
func (db *DB) GetAuditEventByID(id int64) (*AuditEventView, error) {
	var view AuditEventView
	var when string
	var ip, detail *string

	err := db.sql.QueryRow(`
SELECT al.id, al.event_time, al.event_type,
al.actor_id, al.target_id,
al.ip_address, al.details,
COALESCE(a1.username, ''), COALESCE(a2.username, '')
FROM audit_log al
LEFT JOIN accounts a1 ON al.actor_id = a1.id
LEFT JOIN accounts a2 ON al.target_id = a2.id
WHERE al.id = ?
`, id).Scan(
		&view.ID, &when, &view.EventType,
		&view.ActorID, &view.TargetID,
		&ip, &detail,
		&view.ActorUsername, &view.TargetUsername,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get audit event %d: %w", id, err)
	}

	if view.EventTime, err = parseTime(when); err != nil {
		return nil, err
	}
	if ip != nil {
		view.IPAddress = *ip
	}
	if detail != nil {
		view.Details = *detail
	}
	return &view, nil
}
// SetSystemToken records jti as the active service token for a system
// account, replacing any previous one. expiresAt is stored as UTC RFC 3339.
// On replacement, jti, expires_at and created_at are all overwritten, so
// created_at reflects the most recent call.
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
	const stmt = `
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
jti = excluded.jti,
expires_at = excluded.expires_at,
created_at = excluded.created_at
`
	expires := expiresAt.UTC().Format(time.RFC3339)
	if _, err := db.sql.Exec(stmt, accountID, jti, expires, now()); err != nil {
		return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
	}
	return nil
}
// GetSystemToken returns the active service-token record stored for a system
// account. ErrNotFound is returned when no row exists for accountID;
// timestamp parse failures are returned unwrapped.
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
	var tok model.SystemToken
	var expiresRaw, createdRaw string

	err := db.sql.QueryRow(`
SELECT id, account_id, jti, expires_at, created_at
FROM system_tokens WHERE account_id = ?
`, accountID).Scan(&tok.ID, &tok.AccountID, &tok.JTI, &expiresRaw, &createdRaw)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get system token: %w", err)
	}

	if tok.ExpiresAt, err = parseTime(expiresRaw); err != nil {
		return nil, err
	}
	if tok.CreatedAt, err = parseTime(createdRaw); err != nil {
		return nil, err
	}
	return &tok, nil
}