- ui/ui.go: add a pendingLogin struct and a pendingLogins sync.Map to UIServer; add issueTOTPNonce (generates a 128-bit random nonce and stores the accountID with a 90 s TTL) and consumeTOTPNonce (single-use, expiry-checked LoadAndDelete); add a dummyHash() method.
- ui/handlers_auth.go: split handleLoginPost into step 1 (verify password → issue nonce) and step 2 (handleTOTPStep: consume nonce → validate TOTP) via a new finishLogin helper; the password is never transmitted or stored after step 1.
- ui/ui_test.go: refactor newTestMux to reuse the new newTestUIServer; add TestTOTPNonceIssuedAndConsumed, TestTOTPNonceUnknownRejected, TestTOTPNonceExpired, and TestLoginPostPasswordNotInTOTPForm; 11/11 tests pass.
- web/templates/fragments/totp_step.html: replace the 'name=password' hidden field with 'name=totp_nonce'.
- db/accounts.go: add GetAccountByID for the TOTP-step account lookup.
- AUDIT.md: mark F-02 as fixed.

Security: previously the plaintext password survived two HTTP round-trips and lived in the browser DOM during the TOTP step. With the nonce approach the password is verified once and immediately discarded; on step 2 only an opaque random token tied to an account ID (never a credential) crosses the wire. Nonces are single-use and expire after 90 seconds, limiting the window if one is captured.
1013 lines
30 KiB
Go
1013 lines
30 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// CreateAccount inserts a new account record. The UUID is generated
|
|
// automatically. Returns the created Account with its DB-assigned ID and UUID.
|
|
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
|
|
id := uuid.New().String()
|
|
n := now()
|
|
|
|
result, err := db.sql.Exec(`
|
|
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, 'active', ?, ?)
|
|
`, id, username, string(accountType), nullString(passwordHash), n, n)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: create account %q: %w", username, err)
|
|
}
|
|
|
|
rowID, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
|
|
}
|
|
|
|
createdAt, err := parseTime(n)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &model.Account{
|
|
ID: rowID,
|
|
UUID: id,
|
|
Username: username,
|
|
AccountType: accountType,
|
|
Status: model.AccountStatusActive,
|
|
PasswordHash: passwordHash,
|
|
CreatedAt: createdAt,
|
|
UpdatedAt: createdAt,
|
|
}, nil
|
|
}
|
|
|
|
// GetAccountByUUID retrieves an account by its external UUID.
|
|
// Returns ErrNotFound if no matching account exists.
|
|
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
|
|
return db.scanAccount(db.sql.QueryRow(`
|
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
|
status, totp_required,
|
|
totp_secret_enc, totp_secret_nonce,
|
|
created_at, updated_at, deleted_at
|
|
FROM accounts WHERE uuid = ?
|
|
`, accountUUID))
|
|
}
|
|
|
|
// GetAccountByID retrieves an account by its numeric primary key.
|
|
// Returns ErrNotFound if no matching account exists.
|
|
func (db *DB) GetAccountByID(id int64) (*model.Account, error) {
|
|
return db.scanAccount(db.sql.QueryRow(`
|
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
|
status, totp_required,
|
|
totp_secret_enc, totp_secret_nonce,
|
|
created_at, updated_at, deleted_at
|
|
FROM accounts WHERE id = ?
|
|
`, id))
|
|
}
|
|
|
|
// GetAccountByUsername retrieves an account by username (case-insensitive).
|
|
// Returns ErrNotFound if no matching account exists.
|
|
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
|
|
return db.scanAccount(db.sql.QueryRow(`
|
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
|
status, totp_required,
|
|
totp_secret_enc, totp_secret_nonce,
|
|
created_at, updated_at, deleted_at
|
|
FROM accounts WHERE username = ?
|
|
`, username))
|
|
}
|
|
|
|
// ListAccounts returns all non-deleted accounts ordered by username.
|
|
func (db *DB) ListAccounts() ([]*model.Account, error) {
|
|
rows, err := db.sql.Query(`
|
|
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
|
status, totp_required,
|
|
totp_secret_enc, totp_secret_nonce,
|
|
created_at, updated_at, deleted_at
|
|
FROM accounts
|
|
WHERE status != 'deleted'
|
|
ORDER BY username ASC
|
|
`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: list accounts: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var accounts []*model.Account
|
|
for rows.Next() {
|
|
a, err := db.scanAccountRow(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
accounts = append(accounts, a)
|
|
}
|
|
return accounts, rows.Err()
|
|
}
|
|
|
|
// UpdateAccountStatus updates the status field and optionally sets deleted_at.
|
|
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
|
|
n := now()
|
|
var deletedAt *string
|
|
if status == model.AccountStatusDeleted {
|
|
deletedAt = &n
|
|
}
|
|
|
|
_, err := db.sql.Exec(`
|
|
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`, string(status), deletedAt, n, accountID)
|
|
if err != nil {
|
|
return fmt.Errorf("db: update account status: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdatePasswordHash updates the Argon2id password hash for an account.
|
|
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
|
|
_, err := db.sql.Exec(`
|
|
UPDATE accounts SET password_hash = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`, hash, now(), accountID)
|
|
if err != nil {
|
|
return fmt.Errorf("db: update password hash: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// StorePendingTOTP stores the encrypted TOTP secret without enabling the
|
|
// totp_required flag. This is called during enrollment (POST /v1/auth/totp/enroll)
|
|
// so the secret is available for confirmation without yet enforcing TOTP on login.
|
|
// The flag is set to 1 only after the user successfully confirms the code via
|
|
// handleTOTPConfirm, which calls SetTOTP.
|
|
//
|
|
// Security: keeping totp_required=0 during enrollment prevents the user from
|
|
// being locked out if they abandon the enrollment flow after the secret is
|
|
// generated but before they have set up their authenticator app.
|
|
func (db *DB) StorePendingTOTP(accountID int64, secretEnc, secretNonce []byte) error {
|
|
_, err := db.sql.Exec(`
|
|
UPDATE accounts
|
|
SET totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`, secretEnc, secretNonce, now(), accountID)
|
|
if err != nil {
|
|
return fmt.Errorf("db: store pending TOTP: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetTOTP stores the encrypted TOTP secret and marks TOTP as required.
|
|
// Call this only after the user has confirmed the TOTP code; for the initial
|
|
// enrollment step use StorePendingTOTP instead.
|
|
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
|
|
_, err := db.sql.Exec(`
|
|
UPDATE accounts
|
|
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`, secretEnc, secretNonce, now(), accountID)
|
|
if err != nil {
|
|
return fmt.Errorf("db: set TOTP: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
|
|
func (db *DB) ClearTOTP(accountID int64) error {
|
|
_, err := db.sql.Exec(`
|
|
UPDATE accounts
|
|
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
|
|
WHERE id = ?
|
|
`, now(), accountID)
|
|
if err != nil {
|
|
return fmt.Errorf("db: clear TOTP: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// scanAccount scans a single account row from a *sql.Row.
|
|
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
|
|
var a model.Account
|
|
var accountType, status string
|
|
var totpRequired int
|
|
var createdAtStr, updatedAtStr string
|
|
var deletedAtStr *string
|
|
var totpSecretEnc, totpSecretNonce []byte
|
|
|
|
err := row.Scan(
|
|
&a.ID, &a.UUID, &a.Username,
|
|
&accountType, &a.PasswordHash,
|
|
&status, &totpRequired,
|
|
&totpSecretEnc, &totpSecretNonce,
|
|
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: scan account: %w", err)
|
|
}
|
|
|
|
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
|
}
|
|
|
|
// scanAccountRow scans a single account from *sql.Rows.
|
|
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
|
|
var a model.Account
|
|
var accountType, status string
|
|
var totpRequired int
|
|
var createdAtStr, updatedAtStr string
|
|
var deletedAtStr *string
|
|
var totpSecretEnc, totpSecretNonce []byte
|
|
|
|
err := rows.Scan(
|
|
&a.ID, &a.UUID, &a.Username,
|
|
&accountType, &a.PasswordHash,
|
|
&status, &totpRequired,
|
|
&totpSecretEnc, &totpSecretNonce,
|
|
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: scan account row: %w", err)
|
|
}
|
|
|
|
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
|
}
|
|
|
|
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
|
|
a.AccountType = model.AccountType(accountType)
|
|
a.Status = model.AccountStatus(status)
|
|
a.TOTPRequired = totpRequired == 1
|
|
a.TOTPSecretEnc = totpSecretEnc
|
|
a.TOTPSecretNonce = totpSecretNonce
|
|
|
|
var err error
|
|
a.CreatedAt, err = parseTime(createdAtStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
a.UpdatedAt, err = parseTime(updatedAtStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
a.DeletedAt, err = nullableTime(deletedAtStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return a, nil
|
|
}
|
|
|
|
// nullString maps the empty string to a nil pointer, so database/sql writes it
// as SQL NULL. Any other value is returned by address, unchanged.
func nullString(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
|
|
|
|
// GetRoles returns the role strings assigned to an account.
|
|
func (db *DB) GetRoles(accountID int64) ([]string, error) {
|
|
rows, err := db.sql.Query(`
|
|
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
|
|
`, accountID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var roles []string
|
|
for rows.Next() {
|
|
var role string
|
|
if err := rows.Scan(&role); err != nil {
|
|
return nil, fmt.Errorf("db: scan role: %w", err)
|
|
}
|
|
roles = append(roles, role)
|
|
}
|
|
return roles, rows.Err()
|
|
}
|
|
|
|
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
|
|
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
|
|
_, err := db.sql.Exec(`
|
|
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
|
|
VALUES (?, ?, ?, ?)
|
|
`, accountID, role, grantedBy, now())
|
|
if err != nil {
|
|
return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RevokeRole removes a role from an account.
|
|
func (db *DB) RevokeRole(accountID int64, role string) error {
|
|
_, err := db.sql.Exec(`
|
|
DELETE FROM account_roles WHERE account_id = ? AND role = ?
|
|
`, accountID, role)
|
|
if err != nil {
|
|
return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// SetRoles replaces the full role set for an account atomically.
|
|
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
|
|
tx, err := db.sql.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("db: set roles begin tx: %w", err)
|
|
}
|
|
|
|
if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("db: set roles delete existing: %w", err)
|
|
}
|
|
|
|
n := now()
|
|
for _, role := range roles {
|
|
if _, err := tx.Exec(`
|
|
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
|
|
VALUES (?, ?, ?, ?)
|
|
`, accountID, role, grantedBy, n); err != nil {
|
|
_ = tx.Rollback()
|
|
return fmt.Errorf("db: set roles insert %q: %w", role, err)
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("db: set roles commit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HasRole reports whether an account holds the given role.
|
|
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
|
|
var count int
|
|
err := db.sql.QueryRow(`
|
|
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
|
|
`, accountID, role).Scan(&count)
|
|
if err != nil {
|
|
return false, fmt.Errorf("db: has role: %w", err)
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
// WriteServerConfig stores the encrypted Ed25519 signing key.
|
|
// There can only be one row (id=1).
|
|
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
|
|
n := now()
|
|
_, err := db.sql.Exec(`
|
|
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
|
|
VALUES (1, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
signing_key_enc = excluded.signing_key_enc,
|
|
signing_key_nonce = excluded.signing_key_nonce,
|
|
updated_at = excluded.updated_at
|
|
`, signingKeyEnc, signingKeyNonce, n, n)
|
|
if err != nil {
|
|
return fmt.Errorf("db: write server config: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ReadServerConfig returns the encrypted signing key and nonce.
|
|
// Returns ErrNotFound if no config row exists yet.
|
|
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
|
|
err = db.sql.QueryRow(`
|
|
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
|
|
`).Scan(&signingKeyEnc, &signingKeyNonce)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("db: read server config: %w", err)
|
|
}
|
|
return signingKeyEnc, signingKeyNonce, nil
|
|
}
|
|
|
|
// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation.
|
|
// The salt must be stable across restarts so the same passphrase always yields
|
|
// the same master key. There can only be one row (id=1).
|
|
func (db *DB) WriteMasterKeySalt(salt []byte) error {
|
|
n := now()
|
|
_, err := db.sql.Exec(`
|
|
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
|
|
VALUES (1, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
master_key_salt = excluded.master_key_salt,
|
|
updated_at = excluded.updated_at
|
|
`, salt, n, n)
|
|
if err != nil {
|
|
return fmt.Errorf("db: write master key salt: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
|
|
// Returns ErrNotFound if no salt has been stored yet (first run).
|
|
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
|
|
var salt []byte
|
|
err := db.sql.QueryRow(`
|
|
SELECT master_key_salt FROM server_config WHERE id = 1
|
|
`).Scan(&salt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: read master key salt: %w", err)
|
|
}
|
|
if salt == nil {
|
|
return nil, ErrNotFound
|
|
}
|
|
return salt, nil
|
|
}
|
|
|
|
// WritePGCredentials stores or replaces the Postgres credentials for an account.
|
|
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
|
|
n := now()
|
|
_, err := db.sql.Exec(`
|
|
INSERT INTO pg_credentials
|
|
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(account_id) DO UPDATE SET
|
|
pg_host = excluded.pg_host,
|
|
pg_port = excluded.pg_port,
|
|
pg_database = excluded.pg_database,
|
|
pg_username = excluded.pg_username,
|
|
pg_password_enc = excluded.pg_password_enc,
|
|
pg_password_nonce = excluded.pg_password_nonce,
|
|
updated_at = excluded.updated_at
|
|
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n)
|
|
if err != nil {
|
|
return fmt.Errorf("db: write pg credentials: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ReadPGCredentials retrieves the encrypted Postgres credentials for an account.
|
|
// Returns ErrNotFound if no credentials are stored.
|
|
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
|
|
var cred model.PGCredential
|
|
var createdAtStr, updatedAtStr string
|
|
|
|
err := db.sql.QueryRow(`
|
|
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
|
|
pg_password_enc, pg_password_nonce, created_at, updated_at
|
|
FROM pg_credentials WHERE account_id = ?
|
|
`, accountID).Scan(
|
|
&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
|
|
&cred.PGDatabase, &cred.PGUsername,
|
|
&cred.PGPasswordEnc, &cred.PGPasswordNonce,
|
|
&createdAtStr, &updatedAtStr,
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: read pg credentials: %w", err)
|
|
}
|
|
|
|
cred.CreatedAt, err = parseTime(createdAtStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cred.UpdatedAt, err = parseTime(updatedAtStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &cred, nil
|
|
}
|
|
|
|
// WriteAuditEvent appends an audit log entry.
|
|
// Details must never contain credential material.
|
|
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
|
|
_, err := db.sql.Exec(`
|
|
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
|
|
VALUES (?, ?, ?, ?, ?)
|
|
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details))
|
|
if err != nil {
|
|
return fmt.Errorf("db: write audit event %q: %w", eventType, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// TrackToken records a newly issued JWT JTI for revocation tracking.
|
|
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
|
|
_, err := db.sql.Exec(`
|
|
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
|
|
VALUES (?, ?, ?, ?)
|
|
`, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339))
|
|
if err != nil {
|
|
return fmt.Errorf("db: track token %q: %w", jti, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetTokenRecord retrieves a token record by JTI.
|
|
// Returns ErrNotFound if no record exists (token was never issued by this server).
|
|
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
|
|
var rec model.TokenRecord
|
|
var issuedAtStr, expiresAtStr, createdAtStr string
|
|
var revokedAtStr *string
|
|
var revokeReason *string
|
|
|
|
err := db.sql.QueryRow(`
|
|
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
|
|
FROM token_revocation WHERE jti = ?
|
|
`, jti).Scan(
|
|
&rec.ID, &rec.JTI, &rec.AccountID,
|
|
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
|
|
&createdAtStr,
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
|
|
}
|
|
|
|
var parseErr error
|
|
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
rec.CreatedAt, parseErr = parseTime(createdAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
if revokeReason != nil {
|
|
rec.RevokeReason = *revokeReason
|
|
}
|
|
return &rec, nil
|
|
}
|
|
|
|
// RevokeToken marks a token as revoked by JTI.
|
|
func (db *DB) RevokeToken(jti, reason string) error {
|
|
n := now()
|
|
result, err := db.sql.Exec(`
|
|
UPDATE token_revocation
|
|
SET revoked_at = ?, revoke_reason = ?
|
|
WHERE jti = ? AND revoked_at IS NULL
|
|
`, n, nullString(reason), jti)
|
|
if err != nil {
|
|
return fmt.Errorf("db: revoke token %q: %w", jti, err)
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("db: revoke token rows affected: %w", err)
|
|
}
|
|
if rows == 0 {
|
|
return fmt.Errorf("db: token %q not found or already revoked", jti)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RenewToken atomically revokes the old token and tracks the new one in a
|
|
// single SQLite transaction.
|
|
//
|
|
// Security: the two operations must be atomic so that a failure between them
|
|
// cannot leave the user with neither a valid old token nor a new one. With
|
|
// MaxOpenConns(1) and SQLite's serialised write path, BEGIN IMMEDIATE acquires
|
|
// the write lock immediately and prevents any other writer from interleaving.
|
|
func (db *DB) RenewToken(oldJTI, reason, newJTI string, accountID int64, issuedAt, expiresAt time.Time) error {
|
|
tx, err := db.sql.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("db: renew token begin tx: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
n := now()
|
|
|
|
// Revoke the old token.
|
|
result, err := tx.Exec(`
|
|
UPDATE token_revocation
|
|
SET revoked_at = ?, revoke_reason = ?
|
|
WHERE jti = ? AND revoked_at IS NULL
|
|
`, n, nullString(reason), oldJTI)
|
|
if err != nil {
|
|
return fmt.Errorf("db: renew token revoke old %q: %w", oldJTI, err)
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("db: renew token revoke rows affected: %w", err)
|
|
}
|
|
if rows == 0 {
|
|
return fmt.Errorf("db: renew token: old token %q not found or already revoked", oldJTI)
|
|
}
|
|
|
|
// Track the new token.
|
|
_, err = tx.Exec(`
|
|
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
|
|
VALUES (?, ?, ?, ?)
|
|
`, newJTI, accountID,
|
|
issuedAt.UTC().Format(time.RFC3339),
|
|
expiresAt.UTC().Format(time.RFC3339))
|
|
if err != nil {
|
|
return fmt.Errorf("db: renew token track new %q: %w", newJTI, err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("db: renew token commit: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account.
|
|
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
|
|
n := now()
|
|
_, err := db.sql.Exec(`
|
|
UPDATE token_revocation
|
|
SET revoked_at = ?, revoke_reason = ?
|
|
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
|
|
`, n, nullString(reason), accountID, n)
|
|
if err != nil {
|
|
return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
|
|
// Returns the number of rows deleted.
|
|
func (db *DB) PruneExpiredTokens() (int64, error) {
|
|
result, err := db.sql.Exec(`
|
|
DELETE FROM token_revocation WHERE expires_at < ?
|
|
`, now())
|
|
if err != nil {
|
|
return 0, fmt.Errorf("db: prune expired tokens: %w", err)
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// ListTokensForAccount returns all token_revocation rows for the given account,
|
|
// ordered by issued_at descending (newest first).
|
|
func (db *DB) ListTokensForAccount(accountID int64) ([]*model.TokenRecord, error) {
|
|
rows, err := db.sql.Query(`
|
|
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
|
|
FROM token_revocation
|
|
WHERE account_id = ?
|
|
ORDER BY issued_at DESC
|
|
`, accountID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: list tokens for account %d: %w", accountID, err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var records []*model.TokenRecord
|
|
for rows.Next() {
|
|
var rec model.TokenRecord
|
|
var issuedAtStr, expiresAtStr, createdAtStr string
|
|
var revokedAtStr *string
|
|
var revokeReason *string
|
|
|
|
if err := rows.Scan(
|
|
&rec.ID, &rec.JTI, &rec.AccountID,
|
|
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
|
|
&createdAtStr,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("db: scan token record: %w", err)
|
|
}
|
|
|
|
var parseErr error
|
|
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
rec.CreatedAt, parseErr = parseTime(createdAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
if revokeReason != nil {
|
|
rec.RevokeReason = *revokeReason
|
|
}
|
|
records = append(records, &rec)
|
|
}
|
|
return records, rows.Err()
|
|
}
|
|
|
|
// AuditQueryParams holds the filters and paging controls shared by
// ListAuditEvents and ListAuditEventsPaged. The filter fields (AccountID,
// Since, EventType) are optional: a nil pointer or empty string disables that
// filter.
type AuditQueryParams struct {
	// AccountID, if non-nil, matches events where this account is either the
	// actor or the target.
	AccountID *int64
	// Since, if non-nil, keeps only events at or after this instant. It is
	// compared as a UTC RFC 3339 string against event_time.
	Since *time.Time
	// EventType, if non-empty, must equal event_type exactly.
	EventType string
	// Limit caps the number of rows returned. ListAuditEvents applies it only
	// when > 0; ListAuditEventsPaged uses it as the page size.
	Limit int
	// Offset skips rows before the page starts. Only ListAuditEventsPaged uses
	// it; ListAuditEvents ignores it.
	Offset int
}
|
|
|
|
// AuditEventView extends model.AuditEvent with resolved actor/target usernames
// for display. ListAuditEventsPaged and GetAuditEventByID fill the usernames
// from LEFT JOINs against accounts. They are empty when the referenced account
// row does not exist or the ID is NULL, and both are omitted from JSON when
// empty.
//
// The fieldalignment hint is suppressed: the embedded model.AuditEvent layout
// is fixed and changing to explicit fields would break JSON serialisation.
type AuditEventView struct { //nolint:govet
	model.AuditEvent
	// ActorUsername is the username of the account referenced by actor_id, or "".
	ActorUsername  string `json:"actor_username,omitempty"`
	// TargetUsername is the username of the account referenced by target_id, or "".
	TargetUsername string `json:"target_username,omitempty"`
}
|
|
|
|
// ListAuditEvents returns audit log entries matching the given parameters,
|
|
// ordered by event_time ascending. Limit rows are returned if Limit > 0.
|
|
func (db *DB) ListAuditEvents(p AuditQueryParams) ([]*model.AuditEvent, error) {
|
|
query := `
|
|
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
|
|
FROM audit_log
|
|
WHERE 1=1
|
|
`
|
|
args := []interface{}{}
|
|
|
|
if p.AccountID != nil {
|
|
query += ` AND (actor_id = ? OR target_id = ?)`
|
|
args = append(args, *p.AccountID, *p.AccountID)
|
|
}
|
|
if p.EventType != "" {
|
|
query += ` AND event_type = ?`
|
|
args = append(args, p.EventType)
|
|
}
|
|
if p.Since != nil {
|
|
query += ` AND event_time >= ?`
|
|
args = append(args, p.Since.UTC().Format(time.RFC3339))
|
|
}
|
|
|
|
query += ` ORDER BY event_time ASC, id ASC`
|
|
|
|
if p.Limit > 0 {
|
|
query += ` LIMIT ?`
|
|
args = append(args, p.Limit)
|
|
}
|
|
|
|
rows, err := db.sql.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: list audit events: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var events []*model.AuditEvent
|
|
for rows.Next() {
|
|
var ev model.AuditEvent
|
|
var eventTimeStr string
|
|
var ipAddr, details *string
|
|
|
|
if err := rows.Scan(
|
|
&ev.ID, &eventTimeStr, &ev.EventType,
|
|
&ev.ActorID, &ev.TargetID,
|
|
&ipAddr, &details,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("db: scan audit event: %w", err)
|
|
}
|
|
|
|
ev.EventTime, err = parseTime(eventTimeStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if ipAddr != nil {
|
|
ev.IPAddress = *ipAddr
|
|
}
|
|
if details != nil {
|
|
ev.Details = *details
|
|
}
|
|
events = append(events, &ev)
|
|
}
|
|
return events, rows.Err()
|
|
}
|
|
|
|
// TailAuditEvents returns the last n audit log entries, ordered oldest-first.
|
|
func (db *DB) TailAuditEvents(n int) ([]*model.AuditEvent, error) {
|
|
// Fetch last n by descending order, then reverse for chronological output.
|
|
rows, err := db.sql.Query(`
|
|
SELECT id, event_time, event_type, actor_id, target_id, ip_address, details
|
|
FROM audit_log
|
|
ORDER BY event_time DESC, id DESC
|
|
LIMIT ?
|
|
`, n)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: tail audit events: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var events []*model.AuditEvent
|
|
for rows.Next() {
|
|
var ev model.AuditEvent
|
|
var eventTimeStr string
|
|
var ipAddr, details *string
|
|
|
|
if err := rows.Scan(
|
|
&ev.ID, &eventTimeStr, &ev.EventType,
|
|
&ev.ActorID, &ev.TargetID,
|
|
&ipAddr, &details,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("db: scan audit event: %w", err)
|
|
}
|
|
|
|
var parseErr error
|
|
ev.EventTime, parseErr = parseTime(eventTimeStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
if ipAddr != nil {
|
|
ev.IPAddress = *ipAddr
|
|
}
|
|
if details != nil {
|
|
ev.Details = *details
|
|
}
|
|
events = append(events, &ev)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Reverse to oldest-first.
|
|
for i, j := 0, len(events)-1; i < j; i, j = i+1, j-1 {
|
|
events[i], events[j] = events[j], events[i]
|
|
}
|
|
return events, nil
|
|
}
|
|
|
|
// ListAuditEventsPaged returns audit log entries matching params, newest first,
|
|
// with LEFT JOINed actor/target usernames for display. Returns the matching rows
|
|
// and the total count of matching rows (for pagination).
|
|
//
|
|
// Security: No credential material is included in audit_log rows per the
|
|
// WriteAuditEvent contract; joining account usernames is safe for display.
|
|
func (db *DB) ListAuditEventsPaged(p AuditQueryParams) ([]*AuditEventView, int64, error) {
|
|
// Build the shared WHERE clause and args.
|
|
where := " WHERE 1=1"
|
|
args := []interface{}{}
|
|
|
|
if p.AccountID != nil {
|
|
where += ` AND (al.actor_id = ? OR al.target_id = ?)`
|
|
args = append(args, *p.AccountID, *p.AccountID)
|
|
}
|
|
if p.EventType != "" {
|
|
where += ` AND al.event_type = ?`
|
|
args = append(args, p.EventType)
|
|
}
|
|
if p.Since != nil {
|
|
where += ` AND al.event_time >= ?`
|
|
args = append(args, p.Since.UTC().Format(time.RFC3339))
|
|
}
|
|
|
|
// Count total matching rows first.
|
|
countQuery := `SELECT COUNT(*) FROM audit_log al` + where
|
|
var total int64
|
|
if err := db.sql.QueryRow(countQuery, args...).Scan(&total); err != nil {
|
|
return nil, 0, fmt.Errorf("db: count audit events: %w", err)
|
|
}
|
|
|
|
// Fetch the page with username resolution via LEFT JOIN.
|
|
query := `
|
|
SELECT al.id, al.event_time, al.event_type,
|
|
al.actor_id, al.target_id,
|
|
al.ip_address, al.details,
|
|
COALESCE(a1.username, ''), COALESCE(a2.username, '')
|
|
FROM audit_log al
|
|
LEFT JOIN accounts a1 ON al.actor_id = a1.id
|
|
LEFT JOIN accounts a2 ON al.target_id = a2.id` + where + `
|
|
ORDER BY al.event_time DESC, al.id DESC`
|
|
|
|
pageArgs := append(args, p.Limit, p.Offset) //nolint:gocritic // intentional new slice
|
|
query += ` LIMIT ? OFFSET ?`
|
|
|
|
rows, err := db.sql.Query(query, pageArgs...)
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("db: list audit events paged: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
var events []*AuditEventView
|
|
for rows.Next() {
|
|
var ev AuditEventView
|
|
var eventTimeStr string
|
|
var ipAddr, details *string
|
|
|
|
if err := rows.Scan(
|
|
&ev.ID, &eventTimeStr, &ev.EventType,
|
|
&ev.ActorID, &ev.TargetID,
|
|
&ipAddr, &details,
|
|
&ev.ActorUsername, &ev.TargetUsername,
|
|
); err != nil {
|
|
return nil, 0, fmt.Errorf("db: scan audit event view: %w", err)
|
|
}
|
|
|
|
ev.EventTime, err = parseTime(eventTimeStr)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
if ipAddr != nil {
|
|
ev.IPAddress = *ipAddr
|
|
}
|
|
if details != nil {
|
|
ev.Details = *details
|
|
}
|
|
events = append(events, &ev)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, 0, err
|
|
}
|
|
return events, total, nil
|
|
}
|
|
|
|
// GetAuditEventByID fetches a single audit event by its integer primary key,
|
|
// with actor/target usernames resolved via LEFT JOIN. Returns ErrNotFound if
|
|
// no row matches.
|
|
func (db *DB) GetAuditEventByID(id int64) (*AuditEventView, error) {
|
|
row := db.sql.QueryRow(`
|
|
SELECT al.id, al.event_time, al.event_type,
|
|
al.actor_id, al.target_id,
|
|
al.ip_address, al.details,
|
|
COALESCE(a1.username, ''), COALESCE(a2.username, '')
|
|
FROM audit_log al
|
|
LEFT JOIN accounts a1 ON al.actor_id = a1.id
|
|
LEFT JOIN accounts a2 ON al.target_id = a2.id
|
|
WHERE al.id = ?
|
|
`, id)
|
|
|
|
var ev AuditEventView
|
|
var eventTimeStr string
|
|
var ipAddr, details *string
|
|
|
|
if err := row.Scan(
|
|
&ev.ID, &eventTimeStr, &ev.EventType,
|
|
&ev.ActorID, &ev.TargetID,
|
|
&ipAddr, &details,
|
|
&ev.ActorUsername, &ev.TargetUsername,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return nil, fmt.Errorf("db: get audit event %d: %w", id, err)
|
|
}
|
|
|
|
var err error
|
|
ev.EventTime, err = parseTime(eventTimeStr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if ipAddr != nil {
|
|
ev.IPAddress = *ipAddr
|
|
}
|
|
if details != nil {
|
|
ev.Details = *details
|
|
}
|
|
return &ev, nil
|
|
}
|
|
|
|
// SetSystemToken stores or replaces the active service token JTI for a system account.
|
|
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
|
|
n := now()
|
|
_, err := db.sql.Exec(`
|
|
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
|
|
VALUES (?, ?, ?, ?)
|
|
ON CONFLICT(account_id) DO UPDATE SET
|
|
jti = excluded.jti,
|
|
expires_at = excluded.expires_at,
|
|
created_at = excluded.created_at
|
|
`, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n)
|
|
if err != nil {
|
|
return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetSystemToken retrieves the active service token record for a system account.
|
|
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
|
|
var st model.SystemToken
|
|
var expiresAtStr, createdAtStr string
|
|
|
|
err := db.sql.QueryRow(`
|
|
SELECT id, account_id, jti, expires_at, created_at
|
|
FROM system_tokens WHERE account_id = ?
|
|
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("db: get system token: %w", err)
|
|
}
|
|
|
|
var parseErr error
|
|
st.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
st.CreatedAt, parseErr = parseTime(createdAtStr)
|
|
if parseErr != nil {
|
|
return nil, parseErr
|
|
}
|
|
return &st, nil
|
|
}
|