Files
mcias/internal/db/accounts.go
2026-03-11 11:48:49 -07:00

609 lines
18 KiB
Go

package db
import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
"github.com/google/uuid"
)
// CreateAccount inserts a new account in the 'active' state and returns it.
// A fresh random UUID is generated for the external identifier; the integer
// ID is the one the database assigned to the inserted row. An empty
// passwordHash is stored as SQL NULL (via nullString). CreatedAt and
// UpdatedAt on the returned Account are both the insertion timestamp.
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
	accountUUID := uuid.New().String()
	stamp := now()

	res, err := db.sql.Exec(`
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
VALUES (?, ?, ?, ?, 'active', ?, ?)
`, accountUUID, username, string(accountType), nullString(passwordHash), stamp, stamp)
	if err != nil {
		return nil, fmt.Errorf("db: create account %q: %w", username, err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
	}

	ts, err := parseTime(stamp)
	if err != nil {
		return nil, err
	}

	acct := &model.Account{
		ID:           newID,
		UUID:         accountUUID,
		Username:     username,
		AccountType:  accountType,
		Status:       model.AccountStatusActive,
		PasswordHash: passwordHash,
		CreatedAt:    ts,
		UpdatedAt:    ts,
	}
	return acct, nil
}
// GetAccountByUUID looks up a single account by its external UUID.
// No status filter is applied, so deleted accounts are returned too.
// Returns ErrNotFound when no row carries the given UUID.
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE uuid = ?
`, accountUUID)
	return db.scanAccount(row)
}
// GetAccountByUsername retrieves the account whose username matches username.
// No status filter is applied, so deleted accounts are returned too.
// Returns ErrNotFound if no matching account exists.
//
// Matching uses a plain SQL `username = ?`, so whether the comparison is
// case-insensitive depends on the collation declared for accounts.username
// in the schema, which is not visible here.
// NOTE(review): do not assume case-insensitive lookup at call sites unless
// the schema declares e.g. COLLATE NOCASE on that column — confirm.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
	row := db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE username = ?
`, username)
	return db.scanAccount(row)
}
// ListAccounts returns every account whose status is not 'deleted', sorted
// by username ascending. The slice is nil when no accounts qualify. A scan
// failure aborts the listing; an iteration error from rows.Err is returned
// alongside whatever accounts were collected before it.
func (db *DB) ListAccounts() ([]*model.Account, error) {
	rows, err := db.sql.Query(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts
WHERE status != 'deleted'
ORDER BY username ASC
`)
	if err != nil {
		return nil, fmt.Errorf("db: list accounts: %w", err)
	}
	defer rows.Close()

	var out []*model.Account
	for rows.Next() {
		acct, scanErr := db.scanAccountRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, acct)
	}
	return out, rows.Err()
}
// UpdateAccountStatus writes a new status for the account and bumps
// updated_at. deleted_at is set to the current time when the new status is
// AccountStatusDeleted and cleared to NULL for any other status. An ID that
// matches no row is not reported as an error.
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
	stamp := now()

	var deletedAt *string
	if status == model.AccountStatusDeleted {
		deletedAt = &stamp
	}

	if _, err := db.sql.Exec(`
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
WHERE id = ?
`, string(status), deletedAt, stamp, accountID); err != nil {
		return fmt.Errorf("db: update account status: %w", err)
	}
	return nil
}
// UpdatePasswordHash replaces the stored Argon2id password hash for the
// account and bumps updated_at. The hash is written as given. Unlike
// CreateAccount, an empty string is stored as "" rather than NULL; reads
// COALESCE NULL to "" so the two are indistinguishable to callers. An ID that
// matches no row is not reported as an error.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
	stamp := now()
	if _, err := db.sql.Exec(`
UPDATE accounts SET password_hash = ?, updated_at = ?
WHERE id = ?
`, hash, stamp, accountID); err != nil {
		return fmt.Errorf("db: update password hash: %w", err)
	}
	return nil
}
// SetTOTP saves the encrypted TOTP secret and its nonce for the account,
// sets totp_required to 1 and bumps updated_at. Any previously stored secret
// is overwritten.
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`, secretEnc, secretNonce, stamp, accountID); err != nil {
		return fmt.Errorf("db: set TOTP: %w", err)
	}
	return nil
}
// ClearTOTP turns off TOTP for the account: totp_required becomes 0, the
// encrypted secret and nonce are set to NULL, and updated_at is bumped.
func (db *DB) ClearTOTP(accountID int64) error {
	stamp := now()
	if _, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
WHERE id = ?
`, stamp, accountID); err != nil {
		return fmt.Errorf("db: clear TOTP: %w", err)
	}
	return nil
}
// scanAccount decodes one account from a single-row query result. The row
// must select the 12 account columns in the order used by the account
// queries in this file. sql.ErrNoRows becomes ErrNotFound; any other scan
// error is wrapped.
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
	var (
		acct                   model.Account
		kind, state            string
		totpFlag               int
		created, updated       string
		deleted                *string
		secretEnc, secretNonce []byte
	)
	err := row.Scan(
		&acct.ID, &acct.UUID, &acct.Username,
		&kind, &acct.PasswordHash,
		&state, &totpFlag,
		&secretEnc, &secretNonce,
		&created, &updated, &deleted,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: scan account: %w", err)
	}
	return finishAccountScan(&acct, kind, state, totpFlag, secretEnc, secretNonce, created, updated, deleted)
}
// scanAccountRow decodes the account at the current cursor position of rows.
// It uses the same 12-column layout as scanAccount. Any scan error is wrapped;
// there is no ErrNotFound case because the caller has already advanced to a
// valid row.
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
	acct := model.Account{}
	var kindStr, statusStr, createdStr, updatedStr string
	var totpInt int
	var deletedStr *string
	var encSecret, nonceSecret []byte

	if err := rows.Scan(
		&acct.ID, &acct.UUID, &acct.Username,
		&kindStr, &acct.PasswordHash,
		&statusStr, &totpInt,
		&encSecret, &nonceSecret,
		&createdStr, &updatedStr, &deletedStr,
	); err != nil {
		return nil, fmt.Errorf("db: scan account row: %w", err)
	}
	return finishAccountScan(&acct, kindStr, statusStr, totpInt, encSecret, nonceSecret, createdStr, updatedStr, deletedStr)
}
// finishAccountScan fills a with the raw column values produced by
// scanAccount/scanAccountRow and converts them to model types:
//   - the account type and status strings are cast to their model types;
//   - TOTPRequired is true only when the stored integer is exactly 1;
//   - the created/updated timestamps are parsed with parseTime and deleted_at
//     (possibly NULL) with nullableTime.
//
// Fields are assigned in that order. On the first parse failure it returns
// nil and the parse error unchanged.
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
	a.AccountType, a.Status = model.AccountType(accountType), model.AccountStatus(status)
	a.TOTPRequired = totpRequired == 1
	a.TOTPSecretEnc, a.TOTPSecretNonce = totpSecretEnc, totpSecretNonce

	var err error
	if a.CreatedAt, err = parseTime(createdAtStr); err != nil {
		return nil, err
	}
	if a.UpdatedAt, err = parseTime(updatedAtStr); err != nil {
		return nil, err
	}
	if a.DeletedAt, err = nullableTime(deletedAtStr); err != nil {
		return nil, err
	}
	return a, nil
}
// nullString maps the empty string to a nil pointer, which the SQL driver
// stores as NULL. Any other value is returned as a pointer to a copy of it.
func nullString(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
// GetRoles lists the roles held by the account, sorted ascending. The slice
// is nil when the account holds no roles. An iteration error from rows.Err is
// returned alongside whatever roles were read before it.
func (db *DB) GetRoles(accountID int64) ([]string, error) {
	rows, err := db.sql.Query(`
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
`, accountID)
	if err != nil {
		return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
	}
	defer rows.Close()

	var out []string
	for rows.Next() {
		var r string
		if scanErr := rows.Scan(&r); scanErr != nil {
			return nil, fmt.Errorf("db: scan role: %w", scanErr)
		}
		out = append(out, r)
	}
	return out, rows.Err()
}
// GrantRole gives the account the named role, recording grantedBy (nil is
// stored as NULL) and the current time. The statement is INSERT OR IGNORE, so
// granting a role the account already holds leaves the existing row, including
// its granted_by/granted_at, untouched and is not an error.
// NOTE(review): in SQLite, OR IGNORE also suppresses NOT NULL/CHECK/UNIQUE
// violations, so a role rejected by a schema constraint would be dropped
// silently rather than reported.
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
	grantedAt := now()
	if _, err := db.sql.Exec(`
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, grantedAt); err != nil {
		return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
	}
	return nil
}
// RevokeRole deletes the (account, role) assignment. Revoking a role the
// account does not hold deletes nothing and is not an error.
func (db *DB) RevokeRole(accountID int64, role string) error {
	const stmt = `
DELETE FROM account_roles WHERE account_id = ? AND role = ?
`
	if _, err := db.sql.Exec(stmt, accountID, role); err != nil {
		return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
	}
	return nil
}
// SetRoles atomically replaces the full role set for an account. Inside one
// transaction it deletes every existing account_roles row for accountID, then
// inserts one row per distinct entry in roles. All inserted rows share the
// same granted_by and granted_at values. A nil or empty roles slice clears all
// roles. If any step fails, the transaction is rolled back and the previous
// role set is left intact.
//
// Duplicate entries in roles are collapsed before inserting. Presumably
// (account_id, role) is unique in the schema, as GrantRole's INSERT OR IGNORE
// implies, so a repeated role such as ["admin", "admin"] would otherwise fail
// the second insert and abort the whole replacement.
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
	tx, err := db.sql.Begin()
	if err != nil {
		return fmt.Errorf("db: set roles begin tx: %w", err)
	}
	// Rollback after a successful Commit is a no-op that returns
	// sql.ErrTxDone, so deferring it covers every early-return path.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
		return fmt.Errorf("db: set roles delete existing: %w", err)
	}

	n := now()
	seen := make(map[string]struct{}, len(roles))
	for _, role := range roles {
		if _, dup := seen[role]; dup {
			continue
		}
		seen[role] = struct{}{}
		if _, err := tx.Exec(`
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, n); err != nil {
			return fmt.Errorf("db: set roles insert %q: %w", role, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("db: set roles commit: %w", err)
	}
	return nil
}
// HasRole reports whether at least one account_roles row exists for the
// given account and role.
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
	row := db.sql.QueryRow(`
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role)

	var matches int
	if err := row.Scan(&matches); err != nil {
		return false, fmt.Errorf("db: has role: %w", err)
	}
	return matches > 0, nil
}
// WriteServerConfig upserts the encrypted Ed25519 signing key and its nonce
// into the singleton server_config row (id=1). When the row already exists,
// only signing_key_enc, signing_key_nonce and updated_at are replaced;
// created_at and any other columns keep their stored values.
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
VALUES (1, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
signing_key_enc = excluded.signing_key_enc,
signing_key_nonce = excluded.signing_key_nonce,
updated_at = excluded.updated_at
`, signingKeyEnc, signingKeyNonce, stamp, stamp); err != nil {
		return fmt.Errorf("db: write server config: %w", err)
	}
	return nil
}
// ReadServerConfig returns the encrypted Ed25519 signing key and its nonce
// from the singleton server_config row (id=1).
//
// Returns ErrNotFound if that row does not exist yet, or if it exists but
// either the key or the nonce column is NULL. The second case can occur when
// WriteMasterKeySalt created the row before any WriteServerConfig call, if the
// schema allows NULL there. database/sql scans a NULL column into a nil
// []byte with no error, so without this check the caller would receive
// (nil, nil, nil) and treat a missing key as present. This matches the NULL
// check ReadMasterKeySalt performs for the salt.
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
	err = db.sql.QueryRow(`
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
`).Scan(&signingKeyEnc, &signingKeyNonce)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil, ErrNotFound
	}
	if err != nil {
		return nil, nil, fmt.Errorf("db: read server config: %w", err)
	}
	if signingKeyEnc == nil || signingKeyNonce == nil {
		return nil, nil, ErrNotFound
	}
	return signingKeyEnc, signingKeyNonce, nil
}
// WriteMasterKeySalt upserts the Argon2id KDF salt used to derive the master
// key into the singleton server_config row (id=1). The salt has to survive
// restarts so the same passphrase keeps producing the same master key. When
// the row already exists, only master_key_salt and updated_at are replaced.
func (db *DB) WriteMasterKeySalt(salt []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
VALUES (1, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
master_key_salt = excluded.master_key_salt,
updated_at = excluded.updated_at
`, salt, stamp, stamp); err != nil {
		return fmt.Errorf("db: write master key salt: %w", err)
	}
	return nil
}
// ReadMasterKeySalt returns the Argon2id KDF salt stored in server_config
// row id=1. Returns ErrNotFound both when the row is missing and when it
// exists with a NULL salt, which scans to a nil slice (for example, a first
// run, or a row created only by WriteServerConfig).
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
	var salt []byte
	err := db.sql.QueryRow(`
SELECT master_key_salt FROM server_config WHERE id = 1
`).Scan(&salt)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: read master key salt: %w", err)
	case salt == nil:
		return nil, ErrNotFound
	}
	return salt, nil
}
// WritePGCredentials upserts the Postgres connection details for an account,
// keyed on account_id. The password is stored only in encrypted form together
// with its nonce. When a row already exists for the account, every
// connection field and updated_at are replaced, while created_at keeps its
// original value.
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
	stamp := now()
	if _, err := db.sql.Exec(`
INSERT INTO pg_credentials
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
pg_host = excluded.pg_host,
pg_port = excluded.pg_port,
pg_database = excluded.pg_database,
pg_username = excluded.pg_username,
pg_password_enc = excluded.pg_password_enc,
pg_password_nonce = excluded.pg_password_nonce,
updated_at = excluded.updated_at
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, stamp, stamp); err != nil {
		return fmt.Errorf("db: write pg credentials: %w", err)
	}
	return nil
}
// ReadPGCredentials loads the stored Postgres connection details for an
// account. The password is returned still encrypted, together with its nonce.
// Returns ErrNotFound when no credentials are stored for accountID.
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
	cred := new(model.PGCredential)
	var created, updated string

	row := db.sql.QueryRow(`
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
pg_password_enc, pg_password_nonce, created_at, updated_at
FROM pg_credentials WHERE account_id = ?
`, accountID)
	err := row.Scan(
		&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
		&cred.PGDatabase, &cred.PGUsername,
		&cred.PGPasswordEnc, &cred.PGPasswordNonce,
		&created, &updated,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: read pg credentials: %w", err)
	}

	if cred.CreatedAt, err = parseTime(created); err != nil {
		return nil, err
	}
	if cred.UpdatedAt, err = parseTime(updated); err != nil {
		return nil, err
	}
	return cred, nil
}
// WriteAuditEvent appends one row to audit_log. actorID and targetID may be
// nil, and an empty ipAddress or details is stored as NULL.
// details must never contain credential material.
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
	ip, detail := nullString(ipAddress), nullString(details)
	if _, err := db.sql.Exec(`
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
VALUES (?, ?, ?, ?, ?)
`, eventType, actorID, targetID, ip, detail); err != nil {
		return fmt.Errorf("db: write audit event %q: %w", eventType, err)
	}
	return nil
}
// TrackToken records a newly issued JWT's JTI so it can later be revoked.
// issuedAt and expiresAt are stored as RFC 3339 strings in UTC, which drops
// sub-second precision. Inserting a duplicate JTI fails if the schema enforces
// uniqueness on jti (not visible here).
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
	issued := issuedAt.UTC().Format(time.RFC3339)
	expires := expiresAt.UTC().Format(time.RFC3339)

	if _, err := db.sql.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, jti, accountID, issued, expires); err != nil {
		return fmt.Errorf("db: track token %q: %w", jti, err)
	}
	return nil
}
// GetTokenRecord loads the token_revocation row for jti.
// Returns ErrNotFound when no row exists. That happens when the token was
// never issued by this server, or when its row has since been removed by
// PruneExpiredTokens. RevokeReason stays empty when revoke_reason is NULL.
// The timestamps are parsed in the order expires_at, issued_at, created_at,
// revoked_at, and the first parse error is returned unchanged.
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
	var (
		rec                      model.TokenRecord
		issued, expires, created string
		revokedAt, reason        *string
	)
	err := db.sql.QueryRow(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation WHERE jti = ?
`, jti).Scan(
		&rec.ID, &rec.JTI, &rec.AccountID,
		&expires, &issued, &revokedAt, &reason,
		&created,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
	}

	if rec.ExpiresAt, err = parseTime(expires); err != nil {
		return nil, err
	}
	if rec.IssuedAt, err = parseTime(issued); err != nil {
		return nil, err
	}
	if rec.CreatedAt, err = parseTime(created); err != nil {
		return nil, err
	}
	if rec.RevokedAt, err = nullableTime(revokedAt); err != nil {
		return nil, err
	}
	if reason != nil {
		rec.RevokeReason = *reason
	}
	return &rec, nil
}
// RevokeToken marks the not-yet-revoked token identified by jti as revoked,
// recording the current time and the reason (an empty reason is stored as
// NULL). Returns a plain, non-sentinel error when no row was updated, which
// happens if jti is unknown or the token is already revoked.
func (db *DB) RevokeToken(jti, reason string) error {
	res, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, now(), nullString(reason), jti)
	if err != nil {
		return fmt.Errorf("db: revoke token %q: %w", jti, err)
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("db: revoke token rows affected: %w", err)
	case affected == 0:
		return fmt.Errorf("db: token %q not found or already revoked", jti)
	}
	return nil
}
// RevokeAllUserTokens revokes every token for the account that is not yet
// revoked and whose expires_at is later than now. The same timestamp is used
// both as revoked_at and as the expiry cutoff. Matching zero tokens is not an
// error.
// NOTE(review): the cutoff compares stored expires_at values (RFC 3339 UTC,
// see TrackToken) against now(). This is only correct if now() uses the same
// lexically ordered format — confirm.
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
	cutoff := now()
	if _, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
`, cutoff, nullString(reason), accountID, cutoff); err != nil {
		return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
	}
	return nil
}
// PruneExpiredTokens deletes token_revocation rows whose expires_at is before
// the current time, whether or not they were revoked, and returns the number
// of rows deleted. After pruning, GetTokenRecord returns ErrNotFound for those
// JTIs.
//
// NOTE(review): expires_at is compared against now() as stored values.
// TrackToken writes RFC 3339 UTC strings, so this is only correct if now()
// produces the same lexically ordered format — confirm.
func (db *DB) PruneExpiredTokens() (int64, error) {
	result, err := db.sql.Exec(`
DELETE FROM token_revocation WHERE expires_at < ?
`, now())
	if err != nil {
		return 0, fmt.Errorf("db: prune expired tokens: %w", err)
	}
	// Wrap the RowsAffected error with context, as RevokeToken does, rather
	// than returning it bare.
	removed, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("db: prune expired tokens rows affected: %w", err)
	}
	return removed, nil
}
// SetSystemToken upserts the active service-token JTI for a system account,
// keyed on account_id, so each account keeps at most one row. expiresAt is
// stored as an RFC 3339 UTC string. On replacement, jti, expires_at and
// created_at are all overwritten, so created_at reflects when the current
// token was set.
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
	stamp := now()
	expires := expiresAt.UTC().Format(time.RFC3339)
	if _, err := db.sql.Exec(`
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
jti = excluded.jti,
expires_at = excluded.expires_at,
created_at = excluded.created_at
`, accountID, jti, expires, stamp); err != nil {
		return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
	}
	return nil
}
// GetSystemToken loads the stored service-token record for a system account.
// Returns ErrNotFound when no token has been set for accountID. Timestamp
// parse errors are returned unchanged.
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
	st := new(model.SystemToken)
	var expires, created string

	err := db.sql.QueryRow(`
SELECT id, account_id, jti, expires_at, created_at
FROM system_tokens WHERE account_id = ?
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expires, &created)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: get system token: %w", err)
	}

	if st.ExpiresAt, err = parseTime(expires); err != nil {
		return nil, err
	}
	if st.CreatedAt, err = parseTime(created); err != nil {
		return nil, err
	}
	return st, nil
}