checkpoint mciassrv
This commit is contained in:
608
internal/db/accounts.go
Normal file
608
internal/db/accounts.go
Normal file
@@ -0,0 +1,608 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CreateAccount inserts a new account record. The UUID is generated
|
||||
// automatically. Returns the created Account with its DB-assigned ID and UUID.
|
||||
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
|
||||
id := uuid.New().String()
|
||||
n := now()
|
||||
|
||||
result, err := db.sql.Exec(`
|
||||
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, 'active', ?, ?)
|
||||
`, id, username, string(accountType), nullString(passwordHash), n, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create account %q: %w", username, err)
|
||||
}
|
||||
|
||||
rowID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
|
||||
}
|
||||
|
||||
createdAt, err := parseTime(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.Account{
|
||||
ID: rowID,
|
||||
UUID: id,
|
||||
Username: username,
|
||||
AccountType: accountType,
|
||||
Status: model.AccountStatusActive,
|
||||
PasswordHash: passwordHash,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetAccountByUUID retrieves an account by its external UUID.
|
||||
// Returns ErrNotFound if no matching account exists.
|
||||
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
|
||||
return db.scanAccount(db.sql.QueryRow(`
|
||||
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||
status, totp_required,
|
||||
totp_secret_enc, totp_secret_nonce,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM accounts WHERE uuid = ?
|
||||
`, accountUUID))
|
||||
}
|
||||
|
||||
// GetAccountByUsername retrieves an account by username (case-insensitive).
|
||||
// Returns ErrNotFound if no matching account exists.
|
||||
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
|
||||
return db.scanAccount(db.sql.QueryRow(`
|
||||
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||
status, totp_required,
|
||||
totp_secret_enc, totp_secret_nonce,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM accounts WHERE username = ?
|
||||
`, username))
|
||||
}
|
||||
|
||||
// ListAccounts returns all non-deleted accounts ordered by username.
|
||||
func (db *DB) ListAccounts() ([]*model.Account, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
|
||||
status, totp_required,
|
||||
totp_secret_enc, totp_secret_nonce,
|
||||
created_at, updated_at, deleted_at
|
||||
FROM accounts
|
||||
WHERE status != 'deleted'
|
||||
ORDER BY username ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*model.Account
|
||||
for rows.Next() {
|
||||
a, err := db.scanAccountRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accounts = append(accounts, a)
|
||||
}
|
||||
return accounts, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateAccountStatus updates the status field and optionally sets deleted_at.
|
||||
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
|
||||
n := now()
|
||||
var deletedAt *string
|
||||
if status == model.AccountStatusDeleted {
|
||||
deletedAt = &n
|
||||
}
|
||||
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, string(status), deletedAt, n, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: update account status: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePasswordHash updates the Argon2id password hash for an account.
|
||||
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts SET password_hash = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, hash, now(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: update password hash: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTOTP stores the encrypted TOTP secret and marks TOTP as required.
|
||||
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts
|
||||
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, secretEnc, secretNonce, now(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set TOTP: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
|
||||
func (db *DB) ClearTOTP(accountID int64) error {
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE accounts
|
||||
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
|
||||
WHERE id = ?
|
||||
`, now(), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: clear TOTP: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanAccount scans a single account row from a *sql.Row.
|
||||
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
|
||||
var a model.Account
|
||||
var accountType, status string
|
||||
var totpRequired int
|
||||
var createdAtStr, updatedAtStr string
|
||||
var deletedAtStr *string
|
||||
var totpSecretEnc, totpSecretNonce []byte
|
||||
|
||||
err := row.Scan(
|
||||
&a.ID, &a.UUID, &a.Username,
|
||||
&accountType, &a.PasswordHash,
|
||||
&status, &totpRequired,
|
||||
&totpSecretEnc, &totpSecretNonce,
|
||||
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan account: %w", err)
|
||||
}
|
||||
|
||||
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
||||
}
|
||||
|
||||
// scanAccountRow scans a single account from *sql.Rows.
|
||||
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
|
||||
var a model.Account
|
||||
var accountType, status string
|
||||
var totpRequired int
|
||||
var createdAtStr, updatedAtStr string
|
||||
var deletedAtStr *string
|
||||
var totpSecretEnc, totpSecretNonce []byte
|
||||
|
||||
err := rows.Scan(
|
||||
&a.ID, &a.UUID, &a.Username,
|
||||
&accountType, &a.PasswordHash,
|
||||
&status, &totpRequired,
|
||||
&totpSecretEnc, &totpSecretNonce,
|
||||
&createdAtStr, &updatedAtStr, &deletedAtStr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan account row: %w", err)
|
||||
}
|
||||
|
||||
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
|
||||
}
|
||||
|
||||
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
|
||||
a.AccountType = model.AccountType(accountType)
|
||||
a.Status = model.AccountStatus(status)
|
||||
a.TOTPRequired = totpRequired == 1
|
||||
a.TOTPSecretEnc = totpSecretEnc
|
||||
a.TOTPSecretNonce = totpSecretNonce
|
||||
|
||||
var err error
|
||||
a.CreatedAt, err = parseTime(createdAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.UpdatedAt, err = parseTime(updatedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.DeletedAt, err = nullableTime(deletedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// nullString converts an empty string to nil for nullable SQL columns.
|
||||
func nullString(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
// GetRoles returns the role strings assigned to an account.
|
||||
func (db *DB) GetRoles(accountID int64) ([]string, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var roles []string
|
||||
for rows.Next() {
|
||||
var role string
|
||||
if err := rows.Scan(&role); err != nil {
|
||||
return nil, fmt.Errorf("db: scan role: %w", err)
|
||||
}
|
||||
roles = append(roles, role)
|
||||
}
|
||||
return roles, rows.Err()
|
||||
}
|
||||
|
||||
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
|
||||
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, accountID, role, grantedBy, now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeRole removes a role from an account.
|
||||
func (db *DB) RevokeRole(accountID int64, role string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
DELETE FROM account_roles WHERE account_id = ? AND role = ?
|
||||
`, accountID, role)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetRoles replaces the full role set for an account atomically.
|
||||
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
|
||||
tx, err := db.sql.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set roles begin tx: %w", err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: set roles delete existing: %w", err)
|
||||
}
|
||||
|
||||
n := now()
|
||||
for _, role := range roles {
|
||||
if _, err := tx.Exec(`
|
||||
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, accountID, role, grantedBy, n); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: set roles insert %q: %w", role, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("db: set roles commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HasRole reports whether an account holds the given role.
|
||||
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
|
||||
var count int
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
|
||||
`, accountID, role).Scan(&count)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("db: has role: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// WriteServerConfig stores the encrypted Ed25519 signing key.
|
||||
// There can only be one row (id=1).
|
||||
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
|
||||
VALUES (1, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
signing_key_enc = excluded.signing_key_enc,
|
||||
signing_key_nonce = excluded.signing_key_nonce,
|
||||
updated_at = excluded.updated_at
|
||||
`, signingKeyEnc, signingKeyNonce, n, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write server config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadServerConfig returns the encrypted signing key and nonce.
|
||||
// Returns ErrNotFound if no config row exists yet.
|
||||
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
|
||||
err = db.sql.QueryRow(`
|
||||
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
|
||||
`).Scan(&signingKeyEnc, &signingKeyNonce)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("db: read server config: %w", err)
|
||||
}
|
||||
return signingKeyEnc, signingKeyNonce, nil
|
||||
}
|
||||
|
||||
// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation.
|
||||
// The salt must be stable across restarts so the same passphrase always yields
|
||||
// the same master key. There can only be one row (id=1).
|
||||
func (db *DB) WriteMasterKeySalt(salt []byte) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
|
||||
VALUES (1, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
master_key_salt = excluded.master_key_salt,
|
||||
updated_at = excluded.updated_at
|
||||
`, salt, n, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write master key salt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
|
||||
// Returns ErrNotFound if no salt has been stored yet (first run).
|
||||
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
|
||||
var salt []byte
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT master_key_salt FROM server_config WHERE id = 1
|
||||
`).Scan(&salt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: read master key salt: %w", err)
|
||||
}
|
||||
if salt == nil {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
// WritePGCredentials stores or replaces the Postgres credentials for an account.
|
||||
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO pg_credentials
|
||||
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(account_id) DO UPDATE SET
|
||||
pg_host = excluded.pg_host,
|
||||
pg_port = excluded.pg_port,
|
||||
pg_database = excluded.pg_database,
|
||||
pg_username = excluded.pg_username,
|
||||
pg_password_enc = excluded.pg_password_enc,
|
||||
pg_password_nonce = excluded.pg_password_nonce,
|
||||
updated_at = excluded.updated_at
|
||||
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write pg credentials: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadPGCredentials retrieves the encrypted Postgres credentials for an account.
|
||||
// Returns ErrNotFound if no credentials are stored.
|
||||
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
|
||||
var cred model.PGCredential
|
||||
var createdAtStr, updatedAtStr string
|
||||
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
|
||||
pg_password_enc, pg_password_nonce, created_at, updated_at
|
||||
FROM pg_credentials WHERE account_id = ?
|
||||
`, accountID).Scan(
|
||||
&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
|
||||
&cred.PGDatabase, &cred.PGUsername,
|
||||
&cred.PGPasswordEnc, &cred.PGPasswordNonce,
|
||||
&createdAtStr, &updatedAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: read pg credentials: %w", err)
|
||||
}
|
||||
|
||||
cred.CreatedAt, err = parseTime(createdAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cred.UpdatedAt, err = parseTime(updatedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cred, nil
|
||||
}
|
||||
|
||||
// WriteAuditEvent appends an audit log entry.
|
||||
// Details must never contain credential material.
|
||||
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details))
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: write audit event %q: %w", eventType, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TrackToken records a newly issued JWT JTI for revocation tracking.
|
||||
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
`, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339))
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: track token %q: %w", jti, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTokenRecord retrieves a token record by JTI.
|
||||
// Returns ErrNotFound if no record exists (token was never issued by this server).
|
||||
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
|
||||
var rec model.TokenRecord
|
||||
var issuedAtStr, expiresAtStr, createdAtStr string
|
||||
var revokedAtStr *string
|
||||
var revokeReason *string
|
||||
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
|
||||
FROM token_revocation WHERE jti = ?
|
||||
`, jti).Scan(
|
||||
&rec.ID, &rec.JTI, &rec.AccountID,
|
||||
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
|
||||
&createdAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
if revokeReason != nil {
|
||||
rec.RevokeReason = *revokeReason
|
||||
}
|
||||
return &rec, nil
|
||||
}
|
||||
|
||||
// RevokeToken marks a token as revoked by JTI.
|
||||
func (db *DB) RevokeToken(jti, reason string) error {
|
||||
n := now()
|
||||
result, err := db.sql.Exec(`
|
||||
UPDATE token_revocation
|
||||
SET revoked_at = ?, revoke_reason = ?
|
||||
WHERE jti = ? AND revoked_at IS NULL
|
||||
`, n, nullString(reason), jti)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke token %q: %w", jti, err)
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke token rows affected: %w", err)
|
||||
}
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("db: token %q not found or already revoked", jti)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account.
|
||||
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE token_revocation
|
||||
SET revoked_at = ?, revoke_reason = ?
|
||||
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
|
||||
`, n, nullString(reason), accountID, n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
|
||||
// Returns the number of rows deleted.
|
||||
func (db *DB) PruneExpiredTokens() (int64, error) {
|
||||
result, err := db.sql.Exec(`
|
||||
DELETE FROM token_revocation WHERE expires_at < ?
|
||||
`, now())
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("db: prune expired tokens: %w", err)
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// SetSystemToken stores or replaces the active service token JTI for a system account.
|
||||
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
|
||||
n := now()
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(account_id) DO UPDATE SET
|
||||
jti = excluded.jti,
|
||||
expires_at = excluded.expires_at,
|
||||
created_at = excluded.created_at
|
||||
`, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSystemToken retrieves the active service token record for a system account.
|
||||
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
|
||||
var st model.SystemToken
|
||||
var expiresAtStr, createdAtStr string
|
||||
|
||||
err := db.sql.QueryRow(`
|
||||
SELECT id, account_id, jti, expires_at, created_at
|
||||
FROM system_tokens WHERE account_id = ?
|
||||
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get system token: %w", err)
|
||||
}
|
||||
|
||||
var parseErr error
|
||||
st.ExpiresAt, parseErr = parseTime(expiresAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
st.CreatedAt, parseErr = parseTime(createdAtStr)
|
||||
if parseErr != nil {
|
||||
return nil, parseErr
|
||||
}
|
||||
return &st, nil
|
||||
}
|
||||
109
internal/db/db.go
Normal file
109
internal/db/db.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Package db provides the SQLite database access layer for MCIAS.
|
||||
//
|
||||
// Security design:
|
||||
// - All queries use parameterized statements; no string concatenation.
|
||||
// - Foreign keys are enforced (PRAGMA foreign_keys = ON).
|
||||
// - WAL mode is enabled for safe concurrent reads during writes.
|
||||
// - The audit log is append-only: no update or delete operations are provided.
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite" // register the sqlite3 driver
|
||||
)
|
||||
|
||||
// DB wraps a *sql.DB with MCIAS-specific helpers.
|
||||
type DB struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// Open opens (or creates) the SQLite database at path and configures it for
|
||||
// MCIAS use (WAL mode, foreign keys, busy timeout).
|
||||
func Open(path string) (*DB, error) {
|
||||
// The modernc.org/sqlite driver is registered as "sqlite".
|
||||
sqlDB, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: open sqlite: %w", err)
|
||||
}
|
||||
|
||||
// Use a single connection for writes; reads can use the pool.
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
db := &DB{sql: sqlDB}
|
||||
if err := db.configure(); err != nil {
|
||||
_ = sqlDB.Close()
|
||||
return nil, err
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// configure applies PRAGMAs that must be set on every connection.
|
||||
func (db *DB) configure() error {
|
||||
pragmas := []string{
|
||||
"PRAGMA journal_mode=WAL",
|
||||
"PRAGMA foreign_keys=ON",
|
||||
"PRAGMA busy_timeout=5000",
|
||||
"PRAGMA synchronous=NORMAL",
|
||||
}
|
||||
for _, p := range pragmas {
|
||||
if _, err := db.sql.Exec(p); err != nil {
|
||||
return fmt.Errorf("db: configure pragma %q: %w", p, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the underlying database connection.
|
||||
func (db *DB) Close() error {
|
||||
return db.sql.Close()
|
||||
}
|
||||
|
||||
// Ping verifies the database connection is alive.
|
||||
func (db *DB) Ping(ctx context.Context) error {
|
||||
return db.sql.PingContext(ctx)
|
||||
}
|
||||
|
||||
// SQL returns the underlying *sql.DB for use in tests or advanced queries.
|
||||
// Prefer the typed methods on DB for all production code.
|
||||
func (db *DB) SQL() *sql.DB {
|
||||
return db.sql
|
||||
}
|
||||
|
||||
// now returns the current UTC time formatted as ISO-8601.
|
||||
func now() string {
|
||||
return time.Now().UTC().Format(time.RFC3339)
|
||||
}
|
||||
|
||||
// parseTime parses an ISO-8601 UTC time string returned by SQLite.
|
||||
func parseTime(s string) (time.Time, error) {
|
||||
t, err := time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
// Try without timezone suffix (some SQLite defaults).
|
||||
t, err = time.Parse("2006-01-02T15:04:05", s)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("db: parse time %q: %w", s, err)
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
return t.UTC(), nil
|
||||
}
|
||||
|
||||
// ErrNotFound is returned when a requested record does not exist.
|
||||
var ErrNotFound = errors.New("db: record not found")
|
||||
|
||||
// nullableTime converts a *string from SQLite into a *time.Time.
|
||||
func nullableTime(s *string) (*time.Time, error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
t, err := parseTime(*s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
355
internal/db/db_test.go
Normal file
355
internal/db/db_test.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
// openTestDB opens an in-memory SQLite database for testing.
|
||||
func openTestDB(t *testing.T) *DB {
|
||||
t.Helper()
|
||||
db, err := Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Open: %v", err)
|
||||
}
|
||||
if err := Migrate(db); err != nil {
|
||||
t.Fatalf("Migrate: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
return db
|
||||
}
|
||||
|
||||
func TestMigrateIdempotent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
// Run again — should be a no-op.
|
||||
if err := Migrate(db); err != nil {
|
||||
t.Errorf("second Migrate call returned error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAndGetAccount(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
if acct.UUID == "" {
|
||||
t.Error("expected non-empty UUID")
|
||||
}
|
||||
if acct.Username != "alice" {
|
||||
t.Errorf("Username = %q, want %q", acct.Username, "alice")
|
||||
}
|
||||
if acct.Status != model.AccountStatusActive {
|
||||
t.Errorf("Status = %q, want active", acct.Status)
|
||||
}
|
||||
|
||||
// Retrieve by UUID.
|
||||
got, err := db.GetAccountByUUID(acct.UUID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountByUUID: %v", err)
|
||||
}
|
||||
if got.Username != "alice" {
|
||||
t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
|
||||
}
|
||||
|
||||
// Retrieve by username.
|
||||
got2, err := db.GetAccountByUsername("alice")
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountByUsername: %v", err)
|
||||
}
|
||||
if got2.UUID != acct.UUID {
|
||||
t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountNotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.GetAccountByUUID("nonexistent-uuid")
|
||||
if err != ErrNotFound {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetAccountByUsername("nobody")
|
||||
if err != ErrNotFound {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountStatus(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
|
||||
t.Fatalf("UpdateAccountStatus: %v", err)
|
||||
}
|
||||
|
||||
got, err := db.GetAccountByUUID(acct.UUID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountByUUID: %v", err)
|
||||
}
|
||||
if got.Status != model.AccountStatusInactive {
|
||||
t.Errorf("Status = %q, want inactive", got.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAccounts(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
for _, name := range []string{"charlie", "delta", "eve"} {
|
||||
if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
|
||||
t.Fatalf("CreateAccount %q: %v", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
accts, err := db.ListAccounts()
|
||||
if err != nil {
|
||||
t.Fatalf("ListAccounts: %v", err)
|
||||
}
|
||||
if len(accts) != 3 {
|
||||
t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleOperations(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
// GrantRole
|
||||
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
||||
t.Fatalf("GrantRole: %v", err)
|
||||
}
|
||||
// Grant again — should be no-op.
|
||||
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
|
||||
t.Fatalf("GrantRole duplicate: %v", err)
|
||||
}
|
||||
|
||||
roles, err := db.GetRoles(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRoles: %v", err)
|
||||
}
|
||||
if len(roles) != 1 || roles[0] != "admin" {
|
||||
t.Errorf("GetRoles = %v, want [admin]", roles)
|
||||
}
|
||||
|
||||
has, err := db.HasRole(acct.ID, "admin")
|
||||
if err != nil {
|
||||
t.Fatalf("HasRole: %v", err)
|
||||
}
|
||||
if !has {
|
||||
t.Error("expected HasRole to return true for 'admin'")
|
||||
}
|
||||
|
||||
// RevokeRole
|
||||
if err := db.RevokeRole(acct.ID, "admin"); err != nil {
|
||||
t.Fatalf("RevokeRole: %v", err)
|
||||
}
|
||||
roles, err = db.GetRoles(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRoles after revoke: %v", err)
|
||||
}
|
||||
if len(roles) != 0 {
|
||||
t.Errorf("expected no roles after revoke, got %v", roles)
|
||||
}
|
||||
|
||||
// SetRoles
|
||||
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
|
||||
t.Fatalf("SetRoles: %v", err)
|
||||
}
|
||||
roles, err = db.GetRoles(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRoles after SetRoles: %v", err)
|
||||
}
|
||||
if len(roles) != 2 {
|
||||
t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenTrackingAndRevocation(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
jti := "test-jti-1234"
|
||||
issuedAt := time.Now().UTC()
|
||||
expiresAt := issuedAt.Add(time.Hour)
|
||||
|
||||
if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
|
||||
t.Fatalf("TrackToken: %v", err)
|
||||
}
|
||||
|
||||
// Retrieve
|
||||
rec, err := db.GetTokenRecord(jti)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTokenRecord: %v", err)
|
||||
}
|
||||
if rec.JTI != jti {
|
||||
t.Errorf("JTI = %q, want %q", rec.JTI, jti)
|
||||
}
|
||||
if rec.IsRevoked() {
|
||||
t.Error("newly tracked token should not be revoked")
|
||||
}
|
||||
|
||||
// Revoke
|
||||
if err := db.RevokeToken(jti, "test revocation"); err != nil {
|
||||
t.Fatalf("RevokeToken: %v", err)
|
||||
}
|
||||
rec, err = db.GetTokenRecord(jti)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTokenRecord after revoke: %v", err)
|
||||
}
|
||||
if !rec.IsRevoked() {
|
||||
t.Error("token should be revoked after RevokeToken")
|
||||
}
|
||||
|
||||
// Revoking again should fail (already revoked).
|
||||
if err := db.RevokeToken(jti, "again"); err == nil {
|
||||
t.Error("expected error when revoking already-revoked token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetTokenRecordNotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
_, err := db.GetTokenRecord("no-such-jti")
|
||||
if err != ErrNotFound {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPruneExpiredTokens(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
past := time.Now().UTC().Add(-time.Hour)
|
||||
future := time.Now().UTC().Add(time.Hour)
|
||||
|
||||
if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
|
||||
t.Fatalf("TrackToken expired: %v", err)
|
||||
}
|
||||
if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
|
||||
t.Fatalf("TrackToken valid: %v", err)
|
||||
}
|
||||
|
||||
n, err := db.PruneExpiredTokens()
|
||||
if err != nil {
|
||||
t.Fatalf("PruneExpiredTokens: %v", err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Errorf("pruned %d rows, want 1", n)
|
||||
}
|
||||
|
||||
// Valid token should still be retrievable.
|
||||
if _, err := db.GetTokenRecord("valid-jti"); err != nil {
|
||||
t.Errorf("valid token missing after prune: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerConfig(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
// No config initially.
|
||||
_, _, err := db.ReadServerConfig()
|
||||
if err != ErrNotFound {
|
||||
t.Errorf("expected ErrNotFound for missing config, got %v", err)
|
||||
}
|
||||
|
||||
enc := []byte("encrypted-key-data")
|
||||
nonce := []byte("nonce12345678901")
|
||||
|
||||
if err := db.WriteServerConfig(enc, nonce); err != nil {
|
||||
t.Fatalf("WriteServerConfig: %v", err)
|
||||
}
|
||||
|
||||
gotEnc, gotNonce, err := db.ReadServerConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("ReadServerConfig: %v", err)
|
||||
}
|
||||
if string(gotEnc) != string(enc) {
|
||||
t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
|
||||
}
|
||||
if string(gotNonce) != string(nonce) {
|
||||
t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
|
||||
}
|
||||
|
||||
// Overwrite — should work without error.
|
||||
if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
|
||||
t.Fatalf("WriteServerConfig overwrite: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignKeyEnforcement(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
// Attempting to track a token for a non-existent account should fail.
|
||||
err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
|
||||
if err == nil {
|
||||
t.Error("expected foreign key error for non-existent account_id, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPGCredentials(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
enc := []byte("encrypted-pg-password")
|
||||
nonce := []byte("pg-nonce12345678")
|
||||
|
||||
if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
|
||||
t.Fatalf("WritePGCredentials: %v", err)
|
||||
}
|
||||
|
||||
cred, err := db.ReadPGCredentials(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadPGCredentials: %v", err)
|
||||
}
|
||||
if cred.PGHost != "localhost" {
|
||||
t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
|
||||
}
|
||||
if cred.PGDatabase != "mydb" {
|
||||
t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRevokeAllUserTokens(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
future := time.Now().UTC().Add(time.Hour)
|
||||
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
||||
if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
|
||||
t.Fatalf("TrackToken %q: %v", jti, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
|
||||
t.Fatalf("RevokeAllUserTokens: %v", err)
|
||||
}
|
||||
|
||||
for _, jti := range []string{"tok1", "tok2", "tok3"} {
|
||||
rec, err := db.GetTokenRecord(jti)
|
||||
if err != nil {
|
||||
t.Fatalf("GetTokenRecord %q: %v", jti, err)
|
||||
}
|
||||
if !rec.IsRevoked() {
|
||||
t.Errorf("token %q should be revoked", jti)
|
||||
}
|
||||
}
|
||||
}
|
||||
187
internal/db/migrate.go
Normal file
187
internal/db/migrate.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// migration represents a single schema migration with an ID and SQL statement.
|
||||
type migration struct {
|
||||
id int
|
||||
sql string
|
||||
}
|
||||
|
||||
// migrations is the ordered list of schema migrations applied to the database.
|
||||
// Once applied, migrations must never be modified — only new ones appended.
|
||||
var migrations = []migration{
|
||||
{
|
||||
id: 1,
|
||||
sql: `
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS server_config (
|
||||
id INTEGER PRIMARY KEY CHECK (id = 1),
|
||||
signing_key_enc BLOB,
|
||||
signing_key_nonce BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id INTEGER PRIMARY KEY,
|
||||
uuid TEXT NOT NULL UNIQUE,
|
||||
username TEXT NOT NULL UNIQUE COLLATE NOCASE,
|
||||
account_type TEXT NOT NULL CHECK (account_type IN ('human','system')),
|
||||
password_hash TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'active'
|
||||
CHECK (status IN ('active','inactive','deleted')),
|
||||
totp_required INTEGER NOT NULL DEFAULT 0 CHECK (totp_required IN (0,1)),
|
||||
totp_secret_enc BLOB,
|
||||
totp_secret_nonce BLOB,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
deleted_at TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts (username);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_uuid ON accounts (uuid);
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts (status);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS account_roles (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL,
|
||||
granted_by INTEGER REFERENCES accounts(id),
|
||||
granted_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
UNIQUE (account_id, role)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_account_roles_account ON account_roles (account_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS token_revocation (
|
||||
id INTEGER PRIMARY KEY,
|
||||
jti TEXT NOT NULL UNIQUE,
|
||||
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
expires_at TEXT NOT NULL,
|
||||
revoked_at TEXT,
|
||||
revoke_reason TEXT,
|
||||
issued_at TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_token_jti ON token_revocation (jti);
|
||||
CREATE INDEX IF NOT EXISTS idx_token_account ON token_revocation (account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_token_expires ON token_revocation (expires_at);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS system_tokens (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
jti TEXT NOT NULL UNIQUE,
|
||||
expires_at TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS pg_credentials (
|
||||
id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
pg_host TEXT NOT NULL,
|
||||
pg_port INTEGER NOT NULL DEFAULT 5432,
|
||||
pg_database TEXT NOT NULL,
|
||||
pg_username TEXT NOT NULL,
|
||||
pg_password_enc BLOB NOT NULL,
|
||||
pg_password_nonce BLOB NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS audit_log (
|
||||
id INTEGER PRIMARY KEY,
|
||||
event_time TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
event_type TEXT NOT NULL,
|
||||
actor_id INTEGER REFERENCES accounts(id),
|
||||
target_id INTEGER REFERENCES accounts(id),
|
||||
ip_address TEXT,
|
||||
details TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type);
|
||||
`,
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
sql: `
|
||||
-- Add master_key_salt to server_config for Argon2id KDF salt storage.
|
||||
-- The salt must be stable across restarts so the passphrase always yields the same key.
|
||||
-- We allow NULL signing_key_enc/nonce temporarily until the first signing key is generated.
|
||||
ALTER TABLE server_config ADD COLUMN master_key_salt BLOB;
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
// Migrate applies any unapplied schema migrations to the database in order.
|
||||
// It is idempotent: running it multiple times is safe.
|
||||
func Migrate(db *DB) error {
|
||||
// Ensure the schema_version table exists first.
|
||||
if _, err := db.sql.Exec(`
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
version INTEGER NOT NULL
|
||||
)
|
||||
`); err != nil {
|
||||
return fmt.Errorf("db: ensure schema_version: %w", err)
|
||||
}
|
||||
|
||||
currentVersion, err := currentSchemaVersion(db.sql)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: get current schema version: %w", err)
|
||||
}
|
||||
|
||||
for _, m := range migrations {
|
||||
if m.id <= currentVersion {
|
||||
continue
|
||||
}
|
||||
|
||||
tx, err := db.sql.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: begin migration %d transaction: %w", m.id, err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(m.sql); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: apply migration %d: %w", m.id, err)
|
||||
}
|
||||
|
||||
// Update the schema version within the same transaction.
|
||||
if currentVersion == 0 {
|
||||
if _, err := tx.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.id); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: insert schema version %d: %w", m.id, err)
|
||||
}
|
||||
} else {
|
||||
if _, err := tx.Exec(`UPDATE schema_version SET version = ?`, m.id); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: update schema version to %d: %w", m.id, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("db: commit migration %d: %w", m.id, err)
|
||||
}
|
||||
currentVersion = m.id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// currentSchemaVersion returns the current schema version, or 0 if none applied.
|
||||
func currentSchemaVersion(db *sql.DB) (int, error) {
|
||||
var version int
|
||||
err := db.QueryRow(`SELECT version FROM schema_version LIMIT 1`).Scan(&version)
|
||||
if err != nil {
|
||||
// No rows means version 0 (fresh database).
|
||||
return 0, nil //nolint:nilerr
|
||||
}
|
||||
return version, nil
|
||||
}
|
||||
Reference in New Issue
Block a user