checkpoint mciassrv
This commit is contained in:
608
internal/db/accounts.go
Normal file
608
internal/db/accounts.go
Normal file
@@ -0,0 +1,608 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CreateAccount inserts a new account record. The UUID is generated
|
||||
// automatically. Returns the created Account with its DB-assigned ID and UUID.
|
||||
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
|
||||
id := uuid.New().String()
|
||||
n := now()
|
||||
|
||||
result, err := db.sql.Exec(`
|
||||
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, 'active', ?, ?)
|
||||
`, id, username, string(accountType), nullString(passwordHash), n, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create account %q: %w", username, err)
|
||||
}
|
||||
|
||||
rowID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
|
||||
}
|
||||
|
||||
createdAt, err := parseTime(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.Account{
|
||||
ID: rowID,
|
||||
UUID: id,
|
||||
Username: username,
|
||||
AccountType: accountType,
|
||||
Status: model.AccountStatusActive,
|
||||
PasswordHash: passwordHash,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAccountByUUID retrieves an account by its external UUID.
|
||||
// Returns ErrNotFound if no matching account exists.
|
||||
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
|
||||
return db.scanAccount(db.sql.QueryRow(`
|
||||
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||
status, totp_required,
|
||||
totp_secret_enc, totp_secret_nonce,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM accounts WHERE uuid = ?
|
||||
`, accountUUID))
|
||||
}
|
||||
|
||||
// GetAccountByUsername retrieves an account by username (case-insensitive).
|
||||
// Returns ErrNotFound if no matching account exists.
|
||||
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
|
||||
return db.scanAccount(db.sql.QueryRow(`
|
||||
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||
status, totp_required,
|
||||
totp_secret_enc, totp_secret_nonce,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM accounts WHERE username = ?
|
||||
`, username))
|
||||
}
|
||||
|
||||
// ListAccounts returns all non-deleted accounts ordered by username.
|
||||
func (db *DB) ListAccounts() ([]*model.Account, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||
status, totp_required,
|
||||
totp_secret_enc, totp_secret_nonce,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM accounts
|
||||
WHERE status != 'deleted'
|
||||
ORDER BY username ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*model.Account
|
||||
for rows.Next() {
|
||||
a, err := db.scanAccountRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accounts = append(accounts, a)
|
||||
}
|
||||
return accounts, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateAccountStatus updates the status field and optionally sets deleted_at.
|
||||
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
|
||||
n := now()
|
||||
var deletedAt *string
|
||||
if status == model.AccountStatusDeleted {
|
||||
deletedAt = &n
|
||||
}
|
||||
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, string(status), deletedAt, n, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: update account status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePasswordHash updates the Argon2id password hash for an account.
|
||||
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts SET password_hash = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, hash, now(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: update password hash: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTOTP stores the encrypted TOTP secret and marks TOTP as required.
|
||||
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts
|
||||
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, secretEnc, secretNonce, now(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set TOTP: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
|
||||
func (db *DB) ClearTOTP(accountID int64) error {
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts
|
||||
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, now(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: clear TOTP: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanAccount scans a single account row from a *sql.Row.
|
||||
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
|
||||
var a model.Account
|
||||
var accountType, status string
|
||||
var totpRequired int
|
||||
var createdAtStr, updatedAtStr string
|
||||
var deletedAtStr *string
|
||||
var totpSecretEnc, totpSecretNonce []byte
|
||||
|
||||
err := row.Scan(
|
||||
&a.ID, &a.UUID, &a.Username,
|
||||
&accountType, &a.PasswordHash,
|
||||
&status, &totpRequired,
|
||||
&totpSecretEnc, &totpSecretNonce,
|
||||
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan account: %w", err)
|
||||
}
|
||||
|
||||
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
||||
}
|
||||
|
||||
// scanAccountRow scans a single account from *sql.Rows.
|
||||
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
|
||||
var a model.Account
|
||||
var accountType, status string
|
||||
var totpRequired int
|
||||
var createdAtStr, updatedAtStr string
|
||||
var deletedAtStr *string
|
||||
var totpSecretEnc, totpSecretNonce []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&a.ID, &a.UUID, &a.Username,
|
||||
&accountType, &a.PasswordHash,
|
||||
&status, &totpRequired,
|
||||
&totpSecretEnc, &totpSecretNonce,
|
||||
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan account row: %w", err)
|
||||
}
|
||||
|
||||
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
||||
}
|
||||
|
||||
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
|
||||
a.AccountType = model.AccountType(accountType)
|
||||
a.Status = model.AccountStatus(status)
|
||||
a.TOTPRequired = totpRequired == 1
|
||||
a.TOTPSecretEnc = totpSecretEnc
|
||||
a.TOTPSecretNonce = totpSecretNonce
|
||||
|
||||
var err error
|
||||
a.CreatedAt, err = parseTime(createdAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.UpdatedAt, err = parseTime(updatedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.DeletedAt, err = nullableTime(deletedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// nullString converts an empty string to nil for nullable SQL columns.
|
||||
func nullString(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// GetRoles returns the role strings assigned to an account.
|
||||
func (db *DB) GetRoles(accountID int64) ([]string, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var roles []string
|
||||
for rows.Next() {
|
||||
var role string
|
||||
if err := rows.Scan(&role); err != nil {
|
||||
return nil, fmt.Errorf("db: scan role: %w", err)
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
return roles, rows.Err()
|
||||
}
|
||||
|
||||
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
|
||||
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, accountID, role, grantedBy, now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeRole removes a role from an account.
|
||||
func (db *DB) RevokeRole(accountID int64, role string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
DELETE FROM account_roles WHERE account_id = ? AND role = ?
|
||||
`, accountID, role)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRoles replaces the full role set for an account atomically.
|
||||
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
|
||||
tx, err := db.sql.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set roles begin tx: %w", err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: set roles delete existing: %w", err)
|
||||
}
|
||||
|
||||
n := now()
|
||||
for _, role := range roles {
|
||||
if _, err := tx.Exec(`
|
||||
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, accountID, role, grantedBy, n); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: set roles insert %q: %w", role, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("db: set roles commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasRole reports whether an account holds the given role.
|
||||
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
|
||||
var count int
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
|
||||
`, accountID, role).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("db: has role: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// WriteServerConfig stores the encrypted Ed25519 signing key.
|
||||
// There can only be one row (id=1).
|
||||
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
signing_key_enc = excluded.signing_key_enc,
|
||||
signing_key_nonce = excluded.signing_key_nonce,
|
||||
updated_at = excluded.updated_at
|
||||
`, signingKeyEnc, signingKeyNonce, n, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write server config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadServerConfig returns the encrypted signing key and nonce.
|
||||
// Returns ErrNotFound if no config row exists yet.
|
||||
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
|
||||
err = db.sql.QueryRow(`
|
||||
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
|
||||
`).Scan(&signingKeyEnc, &signingKeyNonce)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("db: read server config: %w", err)
|
||||
}
|
||||
return signingKeyEnc, signingKeyNonce, nil
|
||||
}
|
||||
|
||||
// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation.
|
||||
// The salt must be stable across restarts so the same passphrase always yields
|
||||
// the same master key. There can only be one row (id=1).
|
||||
func (db *DB) WriteMasterKeySalt(salt []byte) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
|
||||
VALUES (1, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
master_key_salt = excluded.master_key_salt,
|
||||
updated_at = excluded.updated_at
|
||||
`, salt, n, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write master key salt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
|
||||
// Returns ErrNotFound if no salt has been stored yet (first run).
|
||||
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
|
||||
var salt []byte
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT master_key_salt FROM server_config WHERE id = 1
|
||||
`).Scan(&salt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: read master key salt: %w", err)
|
||||
}
|
||||
if salt == nil {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
// WritePGCredentials stores or replaces the Postgres credentials for an account.
|
||||
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO pg_credentials
|
||||
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(account_id) DO UPDATE SET
|
||||
pg_host = excluded.pg_host,
|
||||
pg_port = excluded.pg_port,
|
||||
pg_database = excluded.pg_database,
|
||||
pg_username = excluded.pg_username,
|
||||
pg_password_enc = excluded.pg_password_enc,
|
||||
pg_password_nonce = excluded.pg_password_nonce,
|
||||
updated_at = excluded.updated_at
|
||||
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write pg credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadPGCredentials retrieves the encrypted Postgres credentials for an account.
|
||||
// Returns ErrNotFound if no credentials are stored.
|
||||
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
|
||||
var cred model.PGCredential
|
||||
var createdAtStr, updatedAtStr string
|
||||
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
|
||||
pg_password_enc, pg_password_nonce, created_at, updated_at
|
||||
FROM pg_credentials WHERE account_id = ?
|
||||
`, accountID).Scan(
|
||||
&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
|
||||
&cred.PGDatabase, &cred.PGUsername,
|
||||
&cred.PGPasswordEnc, &cred.PGPasswordNonce,
|
||||
&createdAtStr, &updatedAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: read pg credentials: %w", err)
|
||||
}
|
||||
|
||||
cred.CreatedAt, err = parseTime(createdAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cred.UpdatedAt, err = parseTime(updatedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cred, nil
|
||||
}
|
||||
|
||||
// WriteAuditEvent appends an audit log entry.
|
||||
// Details must never contain credential material.
|
||||
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details))
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write audit event %q: %w", eventType, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrackToken records a newly issued JWT JTI for revocation tracking.
|
||||
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339))
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: track token %q: %w", jti, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenRecord retrieves a token record by JTI.
|
||||
// Returns ErrNotFound if no record exists (token was never issued by this server).
|
||||
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
|
||||
var rec model.TokenRecord
|
||||
var issuedAtStr, expiresAtStr, createdAtStr string
|
||||
var revokedAtStr *string
|
||||
var revokeReason *string
|
||||
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
|
||||
FROM token_revocation WHERE jti = ?
|
||||
`, jti).Scan(
|
||||
&rec.ID, &rec.JTI, &rec.AccountID,
|
||||
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
|
||||
&createdAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if revokeReason != nil {
|
||||
rec.RevokeReason = *revokeReason
|
||||
}
|
||||
return &rec, nil
|
||||
}
|
||||
|
||||
// RevokeToken marks a token as revoked by JTI.
|
||||
func (db *DB) RevokeToken(jti, reason string) error {
|
||||
n := now()
|
||||
result, err := db.sql.Exec(`
|
||||
UPDATE token_revocation
|
||||
SET revoked_at = ?, revoke_reason = ?
|
||||
WHERE jti = ? AND revoked_at IS NULL
|
||||
`, n, nullString(reason), jti)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke token %q: %w", jti, err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke token rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("db: token %q not found or already revoked", jti)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account.
|
||||
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE token_revocation
|
||||
SET revoked_at = ?, revoke_reason = ?
|
||||
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
|
||||
`, n, nullString(reason), accountID, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
|
||||
// Returns the number of rows deleted.
|
||||
func (db *DB) PruneExpiredTokens() (int64, error) {
|
||||
result, err := db.sql.Exec(`
|
||||
DELETE FROM token_revocation WHERE expires_at < ?
|
||||
`, now())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("db: prune expired tokens: %w", err)
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// SetSystemToken stores or replaces the active service token JTI for a system account.
|
||||
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(account_id) DO UPDATE SET
|
||||
jti = excluded.jti,
|
||||
expires_at = excluded.expires_at,
|
||||
created_at = excluded.created_at
|
||||
`, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSystemToken retrieves the active service token record for a system account.
|
||||
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
|
||||
var st model.SystemToken
|
||||
var expiresAtStr, createdAtStr string
|
||||
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT id, account_id, jti, expires_at, created_at
|
||||
FROM system_tokens WHERE account_id = ?
|
||||
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get system token: %w", err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
st.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
st.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
return &st, nil
|
||||
}
|
||||
Reference in New Issue
Block a user