checkpoint mciassrv

This commit is contained in:
2026-03-11 11:48:24 -07:00
parent 9e4e7aba7a
commit d75a1d6fd3
21 changed files with 5307 additions and 0 deletions

608
internal/db/accounts.go Normal file
View File

@@ -0,0 +1,608 @@
package db
import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
"github.com/google/uuid"
)
// CreateAccount inserts a new account record. The UUID is generated
// automatically. Returns the created Account with its DB-assigned ID and UUID.
func (db *DB) CreateAccount(username string, accountType model.AccountType, passwordHash string) (*model.Account, error) {
id := uuid.New().String()
n := now()
result, err := db.sql.Exec(`
INSERT INTO accounts (uuid, username, account_type, password_hash, status, created_at, updated_at)
VALUES (?, ?, ?, ?, 'active', ?, ?)
`, id, username, string(accountType), nullString(passwordHash), n, n)
if err != nil {
return nil, fmt.Errorf("db: create account %q: %w", username, err)
}
rowID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("db: last insert id for account %q: %w", username, err)
}
createdAt, err := parseTime(n)
if err != nil {
return nil, err
}
return &model.Account{
ID: rowID,
UUID: id,
Username: username,
AccountType: accountType,
Status: model.AccountStatusActive,
PasswordHash: passwordHash,
CreatedAt: createdAt,
UpdatedAt: createdAt,
}, nil
}
// GetAccountByUUID retrieves an account by its external UUID.
// Returns ErrNotFound if no matching account exists.
func (db *DB) GetAccountByUUID(accountUUID string) (*model.Account, error) {
return db.scanAccount(db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE uuid = ?
`, accountUUID))
}
// GetAccountByUsername retrieves an account by username (case-insensitive).
// Returns ErrNotFound if no matching account exists.
func (db *DB) GetAccountByUsername(username string) (*model.Account, error) {
return db.scanAccount(db.sql.QueryRow(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts WHERE username = ?
`, username))
}
// ListAccounts returns all non-deleted accounts ordered by username.
func (db *DB) ListAccounts() ([]*model.Account, error) {
rows, err := db.sql.Query(`
SELECT id, uuid, username, account_type, COALESCE(password_hash,''),
status, totp_required,
totp_secret_enc, totp_secret_nonce,
created_at, updated_at, deleted_at
FROM accounts
WHERE status != 'deleted'
ORDER BY username ASC
`)
if err != nil {
return nil, fmt.Errorf("db: list accounts: %w", err)
}
defer rows.Close()
var accounts []*model.Account
for rows.Next() {
a, err := db.scanAccountRow(rows)
if err != nil {
return nil, err
}
accounts = append(accounts, a)
}
return accounts, rows.Err()
}
// UpdateAccountStatus updates the status field and optionally sets deleted_at.
func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) error {
n := now()
var deletedAt *string
if status == model.AccountStatusDeleted {
deletedAt = &n
}
_, err := db.sql.Exec(`
UPDATE accounts SET status = ?, deleted_at = ?, updated_at = ?
WHERE id = ?
`, string(status), deletedAt, n, accountID)
if err != nil {
return fmt.Errorf("db: update account status: %w", err)
}
return nil
}
// UpdatePasswordHash updates the Argon2id password hash for an account.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
_, err := db.sql.Exec(`
UPDATE accounts SET password_hash = ?, updated_at = ?
WHERE id = ?
`, hash, now(), accountID)
if err != nil {
return fmt.Errorf("db: update password hash: %w", err)
}
return nil
}
// SetTOTP stores the encrypted TOTP secret and marks TOTP as required.
func (db *DB) SetTOTP(accountID int64, secretEnc, secretNonce []byte) error {
_, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 1, totp_secret_enc = ?, totp_secret_nonce = ?, updated_at = ?
WHERE id = ?
`, secretEnc, secretNonce, now(), accountID)
if err != nil {
return fmt.Errorf("db: set TOTP: %w", err)
}
return nil
}
// ClearTOTP removes the TOTP secret and disables TOTP requirement.
func (db *DB) ClearTOTP(accountID int64) error {
_, err := db.sql.Exec(`
UPDATE accounts
SET totp_required = 0, totp_secret_enc = NULL, totp_secret_nonce = NULL, updated_at = ?
WHERE id = ?
`, now(), accountID)
if err != nil {
return fmt.Errorf("db: clear TOTP: %w", err)
}
return nil
}
// scanAccount scans a single account row from a *sql.Row.
func (db *DB) scanAccount(row *sql.Row) (*model.Account, error) {
var a model.Account
var accountType, status string
var totpRequired int
var createdAtStr, updatedAtStr string
var deletedAtStr *string
var totpSecretEnc, totpSecretNonce []byte
err := row.Scan(
&a.ID, &a.UUID, &a.Username,
&accountType, &a.PasswordHash,
&status, &totpRequired,
&totpSecretEnc, &totpSecretNonce,
&createdAtStr, &updatedAtStr, &deletedAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: scan account: %w", err)
}
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
}
// scanAccountRow scans a single account from *sql.Rows.
func (db *DB) scanAccountRow(rows *sql.Rows) (*model.Account, error) {
var a model.Account
var accountType, status string
var totpRequired int
var createdAtStr, updatedAtStr string
var deletedAtStr *string
var totpSecretEnc, totpSecretNonce []byte
err := rows.Scan(
&a.ID, &a.UUID, &a.Username,
&accountType, &a.PasswordHash,
&status, &totpRequired,
&totpSecretEnc, &totpSecretNonce,
&createdAtStr, &updatedAtStr, &deletedAtStr,
)
if err != nil {
return nil, fmt.Errorf("db: scan account row: %w", err)
}
return finishAccountScan(&a, accountType, status, totpRequired, totpSecretEnc, totpSecretNonce, createdAtStr, updatedAtStr, deletedAtStr)
}
func finishAccountScan(a *model.Account, accountType, status string, totpRequired int, totpSecretEnc, totpSecretNonce []byte, createdAtStr, updatedAtStr string, deletedAtStr *string) (*model.Account, error) {
a.AccountType = model.AccountType(accountType)
a.Status = model.AccountStatus(status)
a.TOTPRequired = totpRequired == 1
a.TOTPSecretEnc = totpSecretEnc
a.TOTPSecretNonce = totpSecretNonce
var err error
a.CreatedAt, err = parseTime(createdAtStr)
if err != nil {
return nil, err
}
a.UpdatedAt, err = parseTime(updatedAtStr)
if err != nil {
return nil, err
}
a.DeletedAt, err = nullableTime(deletedAtStr)
if err != nil {
return nil, err
}
return a, nil
}
// nullString converts an empty string to nil for nullable SQL columns.
func nullString(s string) *string {
if s == "" {
return nil
}
return &s
}
// GetRoles returns the role strings assigned to an account.
func (db *DB) GetRoles(accountID int64) ([]string, error) {
rows, err := db.sql.Query(`
SELECT role FROM account_roles WHERE account_id = ? ORDER BY role ASC
`, accountID)
if err != nil {
return nil, fmt.Errorf("db: get roles for account %d: %w", accountID, err)
}
defer rows.Close()
var roles []string
for rows.Next() {
var role string
if err := rows.Scan(&role); err != nil {
return nil, fmt.Errorf("db: scan role: %w", err)
}
roles = append(roles, role)
}
return roles, rows.Err()
}
// GrantRole adds a role to an account. If the role already exists, it is a no-op.
func (db *DB) GrantRole(accountID int64, role string, grantedBy *int64) error {
_, err := db.sql.Exec(`
INSERT OR IGNORE INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, now())
if err != nil {
return fmt.Errorf("db: grant role %q to account %d: %w", role, accountID, err)
}
return nil
}
// RevokeRole removes a role from an account.
func (db *DB) RevokeRole(accountID int64, role string) error {
_, err := db.sql.Exec(`
DELETE FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role)
if err != nil {
return fmt.Errorf("db: revoke role %q from account %d: %w", role, accountID, err)
}
return nil
}
// SetRoles replaces the full role set for an account atomically.
func (db *DB) SetRoles(accountID int64, roles []string, grantedBy *int64) error {
tx, err := db.sql.Begin()
if err != nil {
return fmt.Errorf("db: set roles begin tx: %w", err)
}
if _, err := tx.Exec(`DELETE FROM account_roles WHERE account_id = ?`, accountID); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: set roles delete existing: %w", err)
}
n := now()
for _, role := range roles {
if _, err := tx.Exec(`
INSERT INTO account_roles (account_id, role, granted_by, granted_at)
VALUES (?, ?, ?, ?)
`, accountID, role, grantedBy, n); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: set roles insert %q: %w", role, err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("db: set roles commit: %w", err)
}
return nil
}
// HasRole reports whether an account holds the given role.
func (db *DB) HasRole(accountID int64, role string) (bool, error) {
var count int
err := db.sql.QueryRow(`
SELECT COUNT(*) FROM account_roles WHERE account_id = ? AND role = ?
`, accountID, role).Scan(&count)
if err != nil {
return false, fmt.Errorf("db: has role: %w", err)
}
return count > 0, nil
}
// WriteServerConfig stores the encrypted Ed25519 signing key.
// There can only be one row (id=1).
func (db *DB) WriteServerConfig(signingKeyEnc, signingKeyNonce []byte) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO server_config (id, signing_key_enc, signing_key_nonce, created_at, updated_at)
VALUES (1, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
signing_key_enc = excluded.signing_key_enc,
signing_key_nonce = excluded.signing_key_nonce,
updated_at = excluded.updated_at
`, signingKeyEnc, signingKeyNonce, n, n)
if err != nil {
return fmt.Errorf("db: write server config: %w", err)
}
return nil
}
// ReadServerConfig returns the encrypted signing key and nonce.
// Returns ErrNotFound if no config row exists yet.
func (db *DB) ReadServerConfig() (signingKeyEnc, signingKeyNonce []byte, err error) {
err = db.sql.QueryRow(`
SELECT signing_key_enc, signing_key_nonce FROM server_config WHERE id = 1
`).Scan(&signingKeyEnc, &signingKeyNonce)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil, ErrNotFound
}
if err != nil {
return nil, nil, fmt.Errorf("db: read server config: %w", err)
}
return signingKeyEnc, signingKeyNonce, nil
}
// WriteMasterKeySalt stores the Argon2id KDF salt for the master key derivation.
// The salt must be stable across restarts so the same passphrase always yields
// the same master key. There can only be one row (id=1).
func (db *DB) WriteMasterKeySalt(salt []byte) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO server_config (id, master_key_salt, created_at, updated_at)
VALUES (1, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
master_key_salt = excluded.master_key_salt,
updated_at = excluded.updated_at
`, salt, n, n)
if err != nil {
return fmt.Errorf("db: write master key salt: %w", err)
}
return nil
}
// ReadMasterKeySalt returns the stored Argon2id KDF salt.
// Returns ErrNotFound if no salt has been stored yet (first run).
func (db *DB) ReadMasterKeySalt() ([]byte, error) {
var salt []byte
err := db.sql.QueryRow(`
SELECT master_key_salt FROM server_config WHERE id = 1
`).Scan(&salt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: read master key salt: %w", err)
}
if salt == nil {
return nil, ErrNotFound
}
return salt, nil
}
// WritePGCredentials stores or replaces the Postgres credentials for an account.
func (db *DB) WritePGCredentials(accountID int64, host string, port int, dbName, username string, passwordEnc, passwordNonce []byte) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO pg_credentials
(account_id, pg_host, pg_port, pg_database, pg_username, pg_password_enc, pg_password_nonce, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
pg_host = excluded.pg_host,
pg_port = excluded.pg_port,
pg_database = excluded.pg_database,
pg_username = excluded.pg_username,
pg_password_enc = excluded.pg_password_enc,
pg_password_nonce = excluded.pg_password_nonce,
updated_at = excluded.updated_at
`, accountID, host, port, dbName, username, passwordEnc, passwordNonce, n, n)
if err != nil {
return fmt.Errorf("db: write pg credentials: %w", err)
}
return nil
}
// ReadPGCredentials retrieves the encrypted Postgres credentials for an account.
// Returns ErrNotFound if no credentials are stored.
func (db *DB) ReadPGCredentials(accountID int64) (*model.PGCredential, error) {
var cred model.PGCredential
var createdAtStr, updatedAtStr string
err := db.sql.QueryRow(`
SELECT id, account_id, pg_host, pg_port, pg_database, pg_username,
pg_password_enc, pg_password_nonce, created_at, updated_at
FROM pg_credentials WHERE account_id = ?
`, accountID).Scan(
&cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort,
&cred.PGDatabase, &cred.PGUsername,
&cred.PGPasswordEnc, &cred.PGPasswordNonce,
&createdAtStr, &updatedAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: read pg credentials: %w", err)
}
cred.CreatedAt, err = parseTime(createdAtStr)
if err != nil {
return nil, err
}
cred.UpdatedAt, err = parseTime(updatedAtStr)
if err != nil {
return nil, err
}
return &cred, nil
}
// WriteAuditEvent appends an audit log entry.
// Details must never contain credential material.
func (db *DB) WriteAuditEvent(eventType string, actorID, targetID *int64, ipAddress, details string) error {
_, err := db.sql.Exec(`
INSERT INTO audit_log (event_type, actor_id, target_id, ip_address, details)
VALUES (?, ?, ?, ?, ?)
`, eventType, actorID, targetID, nullString(ipAddress), nullString(details))
if err != nil {
return fmt.Errorf("db: write audit event %q: %w", eventType, err)
}
return nil
}
// TrackToken records a newly issued JWT JTI for revocation tracking.
func (db *DB) TrackToken(jti string, accountID int64, issuedAt, expiresAt time.Time) error {
_, err := db.sql.Exec(`
INSERT INTO token_revocation (jti, account_id, issued_at, expires_at)
VALUES (?, ?, ?, ?)
`, jti, accountID, issuedAt.UTC().Format(time.RFC3339), expiresAt.UTC().Format(time.RFC3339))
if err != nil {
return fmt.Errorf("db: track token %q: %w", jti, err)
}
return nil
}
// GetTokenRecord retrieves a token record by JTI.
// Returns ErrNotFound if no record exists (token was never issued by this server).
func (db *DB) GetTokenRecord(jti string) (*model.TokenRecord, error) {
var rec model.TokenRecord
var issuedAtStr, expiresAtStr, createdAtStr string
var revokedAtStr *string
var revokeReason *string
err := db.sql.QueryRow(`
SELECT id, jti, account_id, expires_at, issued_at, revoked_at, revoke_reason, created_at
FROM token_revocation WHERE jti = ?
`, jti).Scan(
&rec.ID, &rec.JTI, &rec.AccountID,
&expiresAtStr, &issuedAtStr, &revokedAtStr, &revokeReason,
&createdAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: get token record %q: %w", jti, err)
}
var parseErr error
rec.ExpiresAt, parseErr = parseTime(expiresAtStr)
if parseErr != nil {
return nil, parseErr
}
rec.IssuedAt, parseErr = parseTime(issuedAtStr)
if parseErr != nil {
return nil, parseErr
}
rec.CreatedAt, parseErr = parseTime(createdAtStr)
if parseErr != nil {
return nil, parseErr
}
rec.RevokedAt, parseErr = nullableTime(revokedAtStr)
if parseErr != nil {
return nil, parseErr
}
if revokeReason != nil {
rec.RevokeReason = *revokeReason
}
return &rec, nil
}
// RevokeToken marks a token as revoked by JTI.
func (db *DB) RevokeToken(jti, reason string) error {
n := now()
result, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE jti = ? AND revoked_at IS NULL
`, n, nullString(reason), jti)
if err != nil {
return fmt.Errorf("db: revoke token %q: %w", jti, err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("db: revoke token rows affected: %w", err)
}
if rows == 0 {
return fmt.Errorf("db: token %q not found or already revoked", jti)
}
return nil
}
// RevokeAllUserTokens revokes all non-expired, non-revoked tokens for an account.
func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
n := now()
_, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE account_id = ? AND revoked_at IS NULL AND expires_at > ?
`, n, nullString(reason), accountID, n)
if err != nil {
return fmt.Errorf("db: revoke all tokens for account %d: %w", accountID, err)
}
return nil
}
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
// Returns the number of rows deleted.
func (db *DB) PruneExpiredTokens() (int64, error) {
result, err := db.sql.Exec(`
DELETE FROM token_revocation WHERE expires_at < ?
`, now())
if err != nil {
return 0, fmt.Errorf("db: prune expired tokens: %w", err)
}
return result.RowsAffected()
}
// SetSystemToken stores or replaces the active service token JTI for a system account.
func (db *DB) SetSystemToken(accountID int64, jti string, expiresAt time.Time) error {
n := now()
_, err := db.sql.Exec(`
INSERT INTO system_tokens (account_id, jti, expires_at, created_at)
VALUES (?, ?, ?, ?)
ON CONFLICT(account_id) DO UPDATE SET
jti = excluded.jti,
expires_at = excluded.expires_at,
created_at = excluded.created_at
`, accountID, jti, expiresAt.UTC().Format(time.RFC3339), n)
if err != nil {
return fmt.Errorf("db: set system token for account %d: %w", accountID, err)
}
return nil
}
// GetSystemToken retrieves the active service token record for a system account.
func (db *DB) GetSystemToken(accountID int64) (*model.SystemToken, error) {
var st model.SystemToken
var expiresAtStr, createdAtStr string
err := db.sql.QueryRow(`
SELECT id, account_id, jti, expires_at, created_at
FROM system_tokens WHERE account_id = ?
`, accountID).Scan(&st.ID, &st.AccountID, &st.JTI, &expiresAtStr, &createdAtStr)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("db: get system token: %w", err)
}
var parseErr error
st.ExpiresAt, parseErr = parseTime(expiresAtStr)
if parseErr != nil {
return nil, parseErr
}
st.CreatedAt, parseErr = parseTime(createdAtStr)
if parseErr != nil {
return nil, parseErr
}
return &st, nil
}

109
internal/db/db.go Normal file
View File

@@ -0,0 +1,109 @@
// Package db provides the SQLite database access layer for MCIAS.
//
// Security design:
// - All queries use parameterized statements; no string concatenation.
// - Foreign keys are enforced (PRAGMA foreign_keys = ON).
// - WAL mode is enabled for safe concurrent reads during writes.
// - The audit log is append-only: no update or delete operations are provided.
package db
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
_ "modernc.org/sqlite" // register the sqlite3 driver
)
// DB wraps a *sql.DB with MCIAS-specific helpers.
type DB struct {
sql *sql.DB
}
// Open opens (or creates) the SQLite database at path and configures it for
// MCIAS use (WAL mode, foreign keys, busy timeout).
func Open(path string) (*DB, error) {
// The modernc.org/sqlite driver is registered as "sqlite".
sqlDB, err := sql.Open("sqlite", path)
if err != nil {
return nil, fmt.Errorf("db: open sqlite: %w", err)
}
// Use a single connection for writes; reads can use the pool.
sqlDB.SetMaxOpenConns(1)
db := &DB{sql: sqlDB}
if err := db.configure(); err != nil {
_ = sqlDB.Close()
return nil, err
}
return db, nil
}
// configure applies PRAGMAs that must be set on every connection.
func (db *DB) configure() error {
pragmas := []string{
"PRAGMA journal_mode=WAL",
"PRAGMA foreign_keys=ON",
"PRAGMA busy_timeout=5000",
"PRAGMA synchronous=NORMAL",
}
for _, p := range pragmas {
if _, err := db.sql.Exec(p); err != nil {
return fmt.Errorf("db: configure pragma %q: %w", p, err)
}
}
return nil
}
// Close closes the underlying database connection.
func (db *DB) Close() error {
return db.sql.Close()
}
// Ping verifies the database connection is alive.
func (db *DB) Ping(ctx context.Context) error {
return db.sql.PingContext(ctx)
}
// SQL returns the underlying *sql.DB for use in tests or advanced queries.
// Prefer the typed methods on DB for all production code.
func (db *DB) SQL() *sql.DB {
return db.sql
}
// now returns the current UTC time formatted as ISO-8601.
func now() string {
return time.Now().UTC().Format(time.RFC3339)
}
// parseTime parses an ISO-8601 UTC time string returned by SQLite.
func parseTime(s string) (time.Time, error) {
t, err := time.Parse(time.RFC3339, s)
if err != nil {
// Try without timezone suffix (some SQLite defaults).
t, err = time.Parse("2006-01-02T15:04:05", s)
if err != nil {
return time.Time{}, fmt.Errorf("db: parse time %q: %w", s, err)
}
return t.UTC(), nil
}
return t.UTC(), nil
}
// ErrNotFound is returned when a requested record does not exist.
var ErrNotFound = errors.New("db: record not found")
// nullableTime converts a *string from SQLite into a *time.Time.
func nullableTime(s *string) (*time.Time, error) {
if s == nil {
return nil, nil
}
t, err := parseTime(*s)
if err != nil {
return nil, err
}
return &t, nil
}

355
internal/db/db_test.go Normal file
View File

@@ -0,0 +1,355 @@
package db
import (
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
// openTestDB opens an in-memory SQLite database for testing.
func openTestDB(t *testing.T) *DB {
t.Helper()
db, err := Open(":memory:")
if err != nil {
t.Fatalf("Open: %v", err)
}
if err := Migrate(db); err != nil {
t.Fatalf("Migrate: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
return db
}
func TestMigrateIdempotent(t *testing.T) {
db := openTestDB(t)
// Run again — should be a no-op.
if err := Migrate(db); err != nil {
t.Errorf("second Migrate call returned error: %v", err)
}
}
func TestCreateAndGetAccount(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("alice", model.AccountTypeHuman, "$argon2id$v=19$...")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
if acct.UUID == "" {
t.Error("expected non-empty UUID")
}
if acct.Username != "alice" {
t.Errorf("Username = %q, want %q", acct.Username, "alice")
}
if acct.Status != model.AccountStatusActive {
t.Errorf("Status = %q, want active", acct.Status)
}
// Retrieve by UUID.
got, err := db.GetAccountByUUID(acct.UUID)
if err != nil {
t.Fatalf("GetAccountByUUID: %v", err)
}
if got.Username != "alice" {
t.Errorf("fetched Username = %q, want %q", got.Username, "alice")
}
// Retrieve by username.
got2, err := db.GetAccountByUsername("alice")
if err != nil {
t.Fatalf("GetAccountByUsername: %v", err)
}
if got2.UUID != acct.UUID {
t.Errorf("UUID mismatch: got %q, want %q", got2.UUID, acct.UUID)
}
}
func TestGetAccountNotFound(t *testing.T) {
db := openTestDB(t)
_, err := db.GetAccountByUUID("nonexistent-uuid")
if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err)
}
_, err = db.GetAccountByUsername("nobody")
if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
func TestUpdateAccountStatus(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("bob", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusInactive); err != nil {
t.Fatalf("UpdateAccountStatus: %v", err)
}
got, err := db.GetAccountByUUID(acct.UUID)
if err != nil {
t.Fatalf("GetAccountByUUID: %v", err)
}
if got.Status != model.AccountStatusInactive {
t.Errorf("Status = %q, want inactive", got.Status)
}
}
func TestListAccounts(t *testing.T) {
db := openTestDB(t)
for _, name := range []string{"charlie", "delta", "eve"} {
if _, err := db.CreateAccount(name, model.AccountTypeHuman, "hash"); err != nil {
t.Fatalf("CreateAccount %q: %v", name, err)
}
}
accts, err := db.ListAccounts()
if err != nil {
t.Fatalf("ListAccounts: %v", err)
}
if len(accts) != 3 {
t.Errorf("ListAccounts returned %d accounts, want 3", len(accts))
}
}
func TestRoleOperations(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("frank", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
// GrantRole
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
t.Fatalf("GrantRole: %v", err)
}
// Grant again — should be no-op.
if err := db.GrantRole(acct.ID, "admin", nil); err != nil {
t.Fatalf("GrantRole duplicate: %v", err)
}
roles, err := db.GetRoles(acct.ID)
if err != nil {
t.Fatalf("GetRoles: %v", err)
}
if len(roles) != 1 || roles[0] != "admin" {
t.Errorf("GetRoles = %v, want [admin]", roles)
}
has, err := db.HasRole(acct.ID, "admin")
if err != nil {
t.Fatalf("HasRole: %v", err)
}
if !has {
t.Error("expected HasRole to return true for 'admin'")
}
// RevokeRole
if err := db.RevokeRole(acct.ID, "admin"); err != nil {
t.Fatalf("RevokeRole: %v", err)
}
roles, err = db.GetRoles(acct.ID)
if err != nil {
t.Fatalf("GetRoles after revoke: %v", err)
}
if len(roles) != 0 {
t.Errorf("expected no roles after revoke, got %v", roles)
}
// SetRoles
if err := db.SetRoles(acct.ID, []string{"reader", "writer"}, nil); err != nil {
t.Fatalf("SetRoles: %v", err)
}
roles, err = db.GetRoles(acct.ID)
if err != nil {
t.Fatalf("GetRoles after SetRoles: %v", err)
}
if len(roles) != 2 {
t.Errorf("expected 2 roles after SetRoles, got %d", len(roles))
}
}
func TestTokenTrackingAndRevocation(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("grace", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
jti := "test-jti-1234"
issuedAt := time.Now().UTC()
expiresAt := issuedAt.Add(time.Hour)
if err := db.TrackToken(jti, acct.ID, issuedAt, expiresAt); err != nil {
t.Fatalf("TrackToken: %v", err)
}
// Retrieve
rec, err := db.GetTokenRecord(jti)
if err != nil {
t.Fatalf("GetTokenRecord: %v", err)
}
if rec.JTI != jti {
t.Errorf("JTI = %q, want %q", rec.JTI, jti)
}
if rec.IsRevoked() {
t.Error("newly tracked token should not be revoked")
}
// Revoke
if err := db.RevokeToken(jti, "test revocation"); err != nil {
t.Fatalf("RevokeToken: %v", err)
}
rec, err = db.GetTokenRecord(jti)
if err != nil {
t.Fatalf("GetTokenRecord after revoke: %v", err)
}
if !rec.IsRevoked() {
t.Error("token should be revoked after RevokeToken")
}
// Revoking again should fail (already revoked).
if err := db.RevokeToken(jti, "again"); err == nil {
t.Error("expected error when revoking already-revoked token")
}
}
func TestGetTokenRecordNotFound(t *testing.T) {
db := openTestDB(t)
_, err := db.GetTokenRecord("no-such-jti")
if err != ErrNotFound {
t.Errorf("expected ErrNotFound, got %v", err)
}
}
func TestPruneExpiredTokens(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("henry", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
past := time.Now().UTC().Add(-time.Hour)
future := time.Now().UTC().Add(time.Hour)
if err := db.TrackToken("expired-jti", acct.ID, past.Add(-time.Hour), past); err != nil {
t.Fatalf("TrackToken expired: %v", err)
}
if err := db.TrackToken("valid-jti", acct.ID, time.Now(), future); err != nil {
t.Fatalf("TrackToken valid: %v", err)
}
n, err := db.PruneExpiredTokens()
if err != nil {
t.Fatalf("PruneExpiredTokens: %v", err)
}
if n != 1 {
t.Errorf("pruned %d rows, want 1", n)
}
// Valid token should still be retrievable.
if _, err := db.GetTokenRecord("valid-jti"); err != nil {
t.Errorf("valid token missing after prune: %v", err)
}
}
func TestServerConfig(t *testing.T) {
db := openTestDB(t)
// No config initially.
_, _, err := db.ReadServerConfig()
if err != ErrNotFound {
t.Errorf("expected ErrNotFound for missing config, got %v", err)
}
enc := []byte("encrypted-key-data")
nonce := []byte("nonce12345678901")
if err := db.WriteServerConfig(enc, nonce); err != nil {
t.Fatalf("WriteServerConfig: %v", err)
}
gotEnc, gotNonce, err := db.ReadServerConfig()
if err != nil {
t.Fatalf("ReadServerConfig: %v", err)
}
if string(gotEnc) != string(enc) {
t.Errorf("enc mismatch: got %q, want %q", gotEnc, enc)
}
if string(gotNonce) != string(nonce) {
t.Errorf("nonce mismatch: got %q, want %q", gotNonce, nonce)
}
// Overwrite — should work without error.
if err := db.WriteServerConfig([]byte("new-key"), []byte("new-nonce123456")); err != nil {
t.Fatalf("WriteServerConfig overwrite: %v", err)
}
}
func TestForeignKeyEnforcement(t *testing.T) {
db := openTestDB(t)
// Attempting to track a token for a non-existent account should fail.
err := db.TrackToken("jti-x", 999999, time.Now(), time.Now().Add(time.Hour))
if err == nil {
t.Error("expected foreign key error for non-existent account_id, got nil")
}
}
func TestPGCredentials(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("svc", model.AccountTypeSystem, "")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
enc := []byte("encrypted-pg-password")
nonce := []byte("pg-nonce12345678")
if err := db.WritePGCredentials(acct.ID, "localhost", 5432, "mydb", "myuser", enc, nonce); err != nil {
t.Fatalf("WritePGCredentials: %v", err)
}
cred, err := db.ReadPGCredentials(acct.ID)
if err != nil {
t.Fatalf("ReadPGCredentials: %v", err)
}
if cred.PGHost != "localhost" {
t.Errorf("PGHost = %q, want %q", cred.PGHost, "localhost")
}
if cred.PGDatabase != "mydb" {
t.Errorf("PGDatabase = %q, want %q", cred.PGDatabase, "mydb")
}
}
func TestRevokeAllUserTokens(t *testing.T) {
db := openTestDB(t)
acct, err := db.CreateAccount("ivan", model.AccountTypeHuman, "hash")
if err != nil {
t.Fatalf("CreateAccount: %v", err)
}
future := time.Now().UTC().Add(time.Hour)
for _, jti := range []string{"tok1", "tok2", "tok3"} {
if err := db.TrackToken(jti, acct.ID, time.Now(), future); err != nil {
t.Fatalf("TrackToken %q: %v", jti, err)
}
}
if err := db.RevokeAllUserTokens(acct.ID, "account suspended"); err != nil {
t.Fatalf("RevokeAllUserTokens: %v", err)
}
for _, jti := range []string{"tok1", "tok2", "tok3"} {
rec, err := db.GetTokenRecord(jti)
if err != nil {
t.Fatalf("GetTokenRecord %q: %v", jti, err)
}
if !rec.IsRevoked() {
t.Errorf("token %q should be revoked", jti)
}
}
}

187
internal/db/migrate.go Normal file
View File

@@ -0,0 +1,187 @@
package db
import (
"database/sql"
"fmt"
)
// migration represents a single schema migration with an ID and SQL statement.
type migration struct {
id int
sql string
}
// migrations is the ordered list of schema migrations applied to the database.
// Once applied, migrations must never be modified — only new ones appended.
var migrations = []migration{
{
id: 1,
sql: `
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS server_config (
id INTEGER PRIMARY KEY CHECK (id = 1),
signing_key_enc BLOB,
signing_key_nonce BLOB,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY,
uuid TEXT NOT NULL UNIQUE,
username TEXT NOT NULL UNIQUE COLLATE NOCASE,
account_type TEXT NOT NULL CHECK (account_type IN ('human','system')),
password_hash TEXT,
status TEXT NOT NULL DEFAULT 'active'
CHECK (status IN ('active','inactive','deleted')),
totp_required INTEGER NOT NULL DEFAULT 0 CHECK (totp_required IN (0,1)),
totp_secret_enc BLOB,
totp_secret_nonce BLOB,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
deleted_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_accounts_username ON accounts (username);
CREATE INDEX IF NOT EXISTS idx_accounts_uuid ON accounts (uuid);
CREATE INDEX IF NOT EXISTS idx_accounts_status ON accounts (status);
CREATE TABLE IF NOT EXISTS account_roles (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
role TEXT NOT NULL,
granted_by INTEGER REFERENCES accounts(id),
granted_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
UNIQUE (account_id, role)
);
CREATE INDEX IF NOT EXISTS idx_account_roles_account ON account_roles (account_id);
CREATE TABLE IF NOT EXISTS token_revocation (
id INTEGER PRIMARY KEY,
jti TEXT NOT NULL UNIQUE,
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
expires_at TEXT NOT NULL,
revoked_at TEXT,
revoke_reason TEXT,
issued_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE INDEX IF NOT EXISTS idx_token_jti ON token_revocation (jti);
CREATE INDEX IF NOT EXISTS idx_token_account ON token_revocation (account_id);
CREATE INDEX IF NOT EXISTS idx_token_expires ON token_revocation (expires_at);
CREATE TABLE IF NOT EXISTS system_tokens (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
jti TEXT NOT NULL UNIQUE,
expires_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE TABLE IF NOT EXISTS pg_credentials (
id INTEGER PRIMARY KEY,
account_id INTEGER NOT NULL UNIQUE REFERENCES accounts(id) ON DELETE CASCADE,
pg_host TEXT NOT NULL,
pg_port INTEGER NOT NULL DEFAULT 5432,
pg_database TEXT NOT NULL,
pg_username TEXT NOT NULL,
pg_password_enc BLOB NOT NULL,
pg_password_nonce BLOB NOT NULL,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
);
CREATE TABLE IF NOT EXISTS audit_log (
id INTEGER PRIMARY KEY,
event_time TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
event_type TEXT NOT NULL,
actor_id INTEGER REFERENCES accounts(id),
target_id INTEGER REFERENCES accounts(id),
ip_address TEXT,
details TEXT
);
CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time);
CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id);
CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type);
`,
},
{
id: 2,
sql: `
-- Add master_key_salt to server_config for Argon2id KDF salt storage.
-- The salt must be stable across restarts so the passphrase always yields the same key.
-- We allow NULL signing_key_enc/nonce temporarily until the first signing key is generated.
ALTER TABLE server_config ADD COLUMN master_key_salt BLOB;
`,
},
}
// Migrate applies any unapplied schema migrations to the database in order.
// It is idempotent: running it multiple times is safe.
func Migrate(db *DB) error {
// Ensure the schema_version table exists first.
if _, err := db.sql.Exec(`
CREATE TABLE IF NOT EXISTS schema_version (
version INTEGER NOT NULL
)
`); err != nil {
return fmt.Errorf("db: ensure schema_version: %w", err)
}
currentVersion, err := currentSchemaVersion(db.sql)
if err != nil {
return fmt.Errorf("db: get current schema version: %w", err)
}
for _, m := range migrations {
if m.id <= currentVersion {
continue
}
tx, err := db.sql.Begin()
if err != nil {
return fmt.Errorf("db: begin migration %d transaction: %w", m.id, err)
}
if _, err := tx.Exec(m.sql); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: apply migration %d: %w", m.id, err)
}
// Update the schema version within the same transaction.
if currentVersion == 0 {
if _, err := tx.Exec(`INSERT INTO schema_version (version) VALUES (?)`, m.id); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: insert schema version %d: %w", m.id, err)
}
} else {
if _, err := tx.Exec(`UPDATE schema_version SET version = ?`, m.id); err != nil {
_ = tx.Rollback()
return fmt.Errorf("db: update schema version to %d: %w", m.id, err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("db: commit migration %d: %w", m.id, err)
}
currentVersion = m.id
}
return nil
}
// currentSchemaVersion returns the current schema version, or 0 if none applied.
func currentSchemaVersion(db *sql.DB) (int, error) {
var version int
err := db.QueryRow(`SELECT version FROM schema_version LIMIT 1`).Scan(&version)
if err != nil {
// No rows means version 0 (fresh database).
return 0, nil //nolint:nilerr
}
return version, nil
}