Checkpoint: password reset, rule expiry, migrations

- Self-service and admin password-change endpoints
  (PUT /v1/auth/password, PUT /v1/accounts/{id}/password)
- Policy rule time-scoped expiry (not_before / expires_at)
  with migration 000006 and engine filtering
- golang-migrate integration; embedded SQL migrations
- PolicyRecord fieldalignment lint fix

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-12 14:38:38 -07:00
parent d7b69ed983
commit 22158824bd
25 changed files with 1574 additions and 137 deletions

View File

@@ -128,14 +128,23 @@ func (db *DB) UpdateAccountStatus(accountID int64, status model.AccountStatus) e
}
// UpdatePasswordHash updates the Argon2id password hash for an account.
// Returns ErrNotFound if no active account with the given ID exists, consistent
// with the RowsAffected checks in RevokeToken and RenewToken.
func (db *DB) UpdatePasswordHash(accountID int64, hash string) error {
_, err := db.sql.Exec(`
result, err := db.sql.Exec(`
UPDATE accounts SET password_hash = ?, updated_at = ?
WHERE id = ?
`, hash, now(), accountID)
if err != nil {
return fmt.Errorf("db: update password hash: %w", err)
}
rows, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("db: update password hash rows affected: %w", err)
}
if rows == 0 {
return ErrNotFound
}
return nil
}
@@ -640,6 +649,23 @@ func (db *DB) RevokeAllUserTokens(accountID int64, reason string) error {
return nil
}
// RevokeAllUserTokensExcept revokes all non-expired, non-revoked tokens for an
// account except for the token identified by exceptJTI. Used by the
// self-service password change flow to invalidate all other sessions while
// keeping the caller's current session active.
func (db *DB) RevokeAllUserTokensExcept(accountID int64, exceptJTI, reason string) error {
n := now()
_, err := db.sql.Exec(`
UPDATE token_revocation
SET revoked_at = ?, revoke_reason = ?
WHERE account_id = ? AND jti != ? AND revoked_at IS NULL AND expires_at > ?
`, n, nullString(reason), accountID, exceptJTI, n)
if err != nil {
return fmt.Errorf("db: revoke all tokens except %q for account %d: %w", exceptJTI, accountID, err)
}
return nil
}
// PruneExpiredTokens removes token_revocation rows that are past their expiry.
// Returns the number of rows deleted.
func (db *DB) PruneExpiredTokens() (int64, error) {

View File

@@ -21,7 +21,7 @@ var migrationsFS embed.FS
// LatestSchemaVersion is the highest migration version defined in the
// migrations/ directory. Update this constant whenever a new migration file
// is added.
const LatestSchemaVersion = 5
const LatestSchemaVersion = 6
// newMigrate constructs a migrate.Migrate instance backed by the embedded SQL
// files. It opens a dedicated *sql.DB using the same DSN as the main

View File

@@ -0,0 +1,6 @@
-- Add optional time-scoped validity window to policy rules.
-- NULL means "no constraint" (rule is always active / never expires).
-- The policy engine skips rules where not_before > now() or expires_at <= now()
-- at cache-load time (SetRules), not at query time.
ALTER TABLE policy_rules ADD COLUMN not_before TEXT DEFAULT NULL;
ALTER TABLE policy_rules ADD COLUMN expires_at TEXT DEFAULT NULL;

View File

@@ -4,18 +4,23 @@ import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
// policyRuleCols is the column list for all policy rule SELECT queries.
const policyRuleCols = `id, priority, description, rule_json, enabled, created_by, created_at, updated_at, not_before, expires_at`
// CreatePolicyRule inserts a new policy rule record. The returned record
// includes the database-assigned ID and timestamps.
func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string, createdBy *int64) (*model.PolicyRuleRecord, error) {
// notBefore and expiresAt are optional; nil means no constraint.
func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string, createdBy *int64, notBefore, expiresAt *time.Time) (*model.PolicyRuleRecord, error) {
n := now()
result, err := db.sql.Exec(`
INSERT INTO policy_rules (priority, description, rule_json, enabled, created_by, created_at, updated_at)
VALUES (?, ?, ?, 1, ?, ?, ?)
`, priority, description, ruleJSON, createdBy, n, n)
INSERT INTO policy_rules (priority, description, rule_json, enabled, created_by, created_at, updated_at, not_before, expires_at)
VALUES (?, ?, ?, 1, ?, ?, ?, ?, ?)
`, priority, description, ruleJSON, createdBy, n, n, formatNullableTime(notBefore), formatNullableTime(expiresAt))
if err != nil {
return nil, fmt.Errorf("db: create policy rule: %w", err)
}
@@ -39,6 +44,8 @@ func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string
CreatedBy: createdBy,
CreatedAt: createdAt,
UpdatedAt: createdAt,
NotBefore: notBefore,
ExpiresAt: expiresAt,
}, nil
}
@@ -46,7 +53,7 @@ func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string
// Returns ErrNotFound if no such rule exists.
func (db *DB) GetPolicyRule(id int64) (*model.PolicyRuleRecord, error) {
return db.scanPolicyRule(db.sql.QueryRow(`
SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at
SELECT `+policyRuleCols+`
FROM policy_rules WHERE id = ?
`, id))
}
@@ -55,7 +62,7 @@ func (db *DB) GetPolicyRule(id int64) (*model.PolicyRuleRecord, error) {
// When enabledOnly is true, only rules with enabled=1 are returned.
func (db *DB) ListPolicyRules(enabledOnly bool) ([]*model.PolicyRuleRecord, error) {
query := `
SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at
SELECT ` + policyRuleCols + `
FROM policy_rules`
if enabledOnly {
query += ` WHERE enabled = 1`
@@ -80,8 +87,12 @@ func (db *DB) ListPolicyRules(enabledOnly bool) ([]*model.PolicyRuleRecord, erro
}
// UpdatePolicyRule updates the mutable fields of a policy rule.
// Only the fields in the update map are changed; other fields are untouched.
func (db *DB) UpdatePolicyRule(id int64, description *string, priority *int, ruleJSON *string) error {
// Only non-nil fields are changed; nil fields are left untouched.
// For notBefore and expiresAt, use a non-nil pointer-to-pointer:
// - nil (outer) → don't change
// - non-nil → nil → set column to NULL
// - non-nil → non-nil → set column to the time value
func (db *DB) UpdatePolicyRule(id int64, description *string, priority *int, ruleJSON *string, notBefore, expiresAt **time.Time) error {
n := now()
// Build SET clause dynamically to only update provided fields.
@@ -102,6 +113,14 @@ func (db *DB) UpdatePolicyRule(id int64, description *string, priority *int, rul
setClauses += ", rule_json = ?"
args = append(args, *ruleJSON)
}
if notBefore != nil {
setClauses += ", not_before = ?"
args = append(args, formatNullableTime(*notBefore))
}
if expiresAt != nil {
setClauses += ", expires_at = ?"
args = append(args, formatNullableTime(*expiresAt))
}
args = append(args, id)
_, err := db.sql.Exec(`UPDATE policy_rules SET `+setClauses+` WHERE id = ?`, args...)
@@ -141,10 +160,12 @@ func (db *DB) scanPolicyRule(row *sql.Row) (*model.PolicyRuleRecord, error) {
var enabledInt int
var createdAtStr, updatedAtStr string
var createdBy *int64
var notBeforeStr, expiresAtStr *string
err := row.Scan(
&r.ID, &r.Priority, &r.Description, &r.RuleJSON,
&enabledInt, &createdBy, &createdAtStr, &updatedAtStr,
&notBeforeStr, &expiresAtStr,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
@@ -153,7 +174,7 @@ func (db *DB) scanPolicyRule(row *sql.Row) (*model.PolicyRuleRecord, error) {
return nil, fmt.Errorf("db: scan policy rule: %w", err)
}
return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr)
return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr, notBeforeStr, expiresAtStr)
}
// scanPolicyRuleRow scans a single policy rule from *sql.Rows.
@@ -162,19 +183,21 @@ func (db *DB) scanPolicyRuleRow(rows *sql.Rows) (*model.PolicyRuleRecord, error)
var enabledInt int
var createdAtStr, updatedAtStr string
var createdBy *int64
var notBeforeStr, expiresAtStr *string
err := rows.Scan(
&r.ID, &r.Priority, &r.Description, &r.RuleJSON,
&enabledInt, &createdBy, &createdAtStr, &updatedAtStr,
&notBeforeStr, &expiresAtStr,
)
if err != nil {
return nil, fmt.Errorf("db: scan policy rule row: %w", err)
}
return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr)
return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr, notBeforeStr, expiresAtStr)
}
func finishPolicyRuleScan(r *model.PolicyRuleRecord, enabledInt int, createdBy *int64, createdAtStr, updatedAtStr string) (*model.PolicyRuleRecord, error) {
func finishPolicyRuleScan(r *model.PolicyRuleRecord, enabledInt int, createdBy *int64, createdAtStr, updatedAtStr string, notBeforeStr, expiresAtStr *string) (*model.PolicyRuleRecord, error) {
r.Enabled = enabledInt == 1
r.CreatedBy = createdBy
@@ -187,5 +210,23 @@ func finishPolicyRuleScan(r *model.PolicyRuleRecord, enabledInt int, createdBy *
if err != nil {
return nil, err
}
r.NotBefore, err = nullableTime(notBeforeStr)
if err != nil {
return nil, err
}
r.ExpiresAt, err = nullableTime(expiresAtStr)
if err != nil {
return nil, err
}
return r, nil
}
// formatNullableTime converts a *time.Time to a *string suitable for SQLite.
// Returns nil if the input is nil (stores NULL).
func formatNullableTime(t *time.Time) *string {
if t == nil {
return nil
}
s := t.UTC().Format(time.RFC3339)
return &s
}

View File

@@ -3,6 +3,7 @@ package db
import (
"errors"
"testing"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
@@ -11,7 +12,7 @@ func TestCreateAndGetPolicyRule(t *testing.T) {
db := openTestDB(t)
ruleJSON := `{"actions":["pgcreds:read"],"resource_type":"pgcreds","effect":"allow"}`
rec, err := db.CreatePolicyRule("test rule", 50, ruleJSON, nil)
rec, err := db.CreatePolicyRule("test rule", 50, ruleJSON, nil, nil, nil)
if err != nil {
t.Fatalf("CreatePolicyRule: %v", err)
}
@@ -49,9 +50,9 @@ func TestGetPolicyRule_NotFound(t *testing.T) {
func TestListPolicyRules(t *testing.T) {
db := openTestDB(t)
_, _ = db.CreatePolicyRule("rule A", 100, `{"effect":"allow"}`, nil)
_, _ = db.CreatePolicyRule("rule B", 50, `{"effect":"deny"}`, nil)
_, _ = db.CreatePolicyRule("rule C", 200, `{"effect":"allow"}`, nil)
_, _ = db.CreatePolicyRule("rule A", 100, `{"effect":"allow"}`, nil, nil, nil)
_, _ = db.CreatePolicyRule("rule B", 50, `{"effect":"deny"}`, nil, nil, nil)
_, _ = db.CreatePolicyRule("rule C", 200, `{"effect":"allow"}`, nil, nil, nil)
rules, err := db.ListPolicyRules(false)
if err != nil {
@@ -70,8 +71,8 @@ func TestListPolicyRules(t *testing.T) {
func TestListPolicyRules_EnabledOnly(t *testing.T) {
db := openTestDB(t)
r1, _ := db.CreatePolicyRule("enabled rule", 100, `{"effect":"allow"}`, nil)
r2, _ := db.CreatePolicyRule("disabled rule", 100, `{"effect":"deny"}`, nil)
r1, _ := db.CreatePolicyRule("enabled rule", 100, `{"effect":"allow"}`, nil, nil, nil)
r2, _ := db.CreatePolicyRule("disabled rule", 100, `{"effect":"deny"}`, nil, nil, nil)
if err := db.SetPolicyRuleEnabled(r2.ID, false); err != nil {
t.Fatalf("SetPolicyRuleEnabled: %v", err)
@@ -100,11 +101,11 @@ func TestListPolicyRules_EnabledOnly(t *testing.T) {
func TestUpdatePolicyRule(t *testing.T) {
db := openTestDB(t)
rec, _ := db.CreatePolicyRule("original", 100, `{"effect":"allow"}`, nil)
rec, _ := db.CreatePolicyRule("original", 100, `{"effect":"allow"}`, nil, nil, nil)
newDesc := "updated description"
newPriority := 25
if err := db.UpdatePolicyRule(rec.ID, &newDesc, &newPriority, nil); err != nil {
if err := db.UpdatePolicyRule(rec.ID, &newDesc, &newPriority, nil, nil, nil); err != nil {
t.Fatalf("UpdatePolicyRule: %v", err)
}
@@ -127,10 +128,10 @@ func TestUpdatePolicyRule(t *testing.T) {
func TestUpdatePolicyRule_RuleJSON(t *testing.T) {
db := openTestDB(t)
rec, _ := db.CreatePolicyRule("rule", 100, `{"effect":"allow"}`, nil)
rec, _ := db.CreatePolicyRule("rule", 100, `{"effect":"allow"}`, nil, nil, nil)
newJSON := `{"effect":"deny","roles":["auditor"]}`
if err := db.UpdatePolicyRule(rec.ID, nil, nil, &newJSON); err != nil {
if err := db.UpdatePolicyRule(rec.ID, nil, nil, &newJSON, nil, nil); err != nil {
t.Fatalf("UpdatePolicyRule (json only): %v", err)
}
@@ -150,7 +151,7 @@ func TestUpdatePolicyRule_RuleJSON(t *testing.T) {
func TestSetPolicyRuleEnabled(t *testing.T) {
db := openTestDB(t)
rec, _ := db.CreatePolicyRule("toggle rule", 100, `{"effect":"allow"}`, nil)
rec, _ := db.CreatePolicyRule("toggle rule", 100, `{"effect":"allow"}`, nil, nil, nil)
if !rec.Enabled {
t.Fatal("new rule should be enabled")
}
@@ -175,7 +176,7 @@ func TestSetPolicyRuleEnabled(t *testing.T) {
func TestDeletePolicyRule(t *testing.T) {
db := openTestDB(t)
rec, _ := db.CreatePolicyRule("to delete", 100, `{"effect":"allow"}`, nil)
rec, _ := db.CreatePolicyRule("to delete", 100, `{"effect":"allow"}`, nil, nil, nil)
if err := db.DeletePolicyRule(rec.ID); err != nil {
t.Fatalf("DeletePolicyRule: %v", err)
@@ -200,7 +201,7 @@ func TestCreatePolicyRule_WithCreatedBy(t *testing.T) {
db := openTestDB(t)
acct, _ := db.CreateAccount("policy-creator", model.AccountTypeHuman, "hash")
rec, err := db.CreatePolicyRule("by user", 100, `{"effect":"allow"}`, &acct.ID)
rec, err := db.CreatePolicyRule("by user", 100, `{"effect":"allow"}`, &acct.ID, nil, nil)
if err != nil {
t.Fatalf("CreatePolicyRule with createdBy: %v", err)
}
@@ -210,3 +211,111 @@ func TestCreatePolicyRule_WithCreatedBy(t *testing.T) {
t.Errorf("expected CreatedBy=%d, got %v", acct.ID, got.CreatedBy)
}
}
func TestCreatePolicyRule_WithExpiresAt(t *testing.T) {
db := openTestDB(t)
exp := time.Date(2030, 6, 1, 0, 0, 0, 0, time.UTC)
rec, err := db.CreatePolicyRule("expiring rule", 100, `{"effect":"allow"}`, nil, nil, &exp)
if err != nil {
t.Fatalf("CreatePolicyRule with expiresAt: %v", err)
}
got, err := db.GetPolicyRule(rec.ID)
if err != nil {
t.Fatalf("GetPolicyRule: %v", err)
}
if got.ExpiresAt == nil {
t.Fatal("expected ExpiresAt to be set")
}
if !got.ExpiresAt.Equal(exp) {
t.Errorf("expected ExpiresAt=%v, got %v", exp, *got.ExpiresAt)
}
if got.NotBefore != nil {
t.Errorf("expected NotBefore=nil, got %v", *got.NotBefore)
}
}
func TestCreatePolicyRule_WithNotBefore(t *testing.T) {
db := openTestDB(t)
nb := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC)
rec, err := db.CreatePolicyRule("scheduled rule", 100, `{"effect":"allow"}`, nil, &nb, nil)
if err != nil {
t.Fatalf("CreatePolicyRule with notBefore: %v", err)
}
got, err := db.GetPolicyRule(rec.ID)
if err != nil {
t.Fatalf("GetPolicyRule: %v", err)
}
if got.NotBefore == nil {
t.Fatal("expected NotBefore to be set")
}
if !got.NotBefore.Equal(nb) {
t.Errorf("expected NotBefore=%v, got %v", nb, *got.NotBefore)
}
if got.ExpiresAt != nil {
t.Errorf("expected ExpiresAt=nil, got %v", *got.ExpiresAt)
}
}
func TestCreatePolicyRule_WithBothTimes(t *testing.T) {
db := openTestDB(t)
nb := time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC)
exp := time.Date(2030, 6, 1, 0, 0, 0, 0, time.UTC)
rec, err := db.CreatePolicyRule("windowed rule", 100, `{"effect":"allow"}`, nil, &nb, &exp)
if err != nil {
t.Fatalf("CreatePolicyRule with both times: %v", err)
}
got, err := db.GetPolicyRule(rec.ID)
if err != nil {
t.Fatalf("GetPolicyRule: %v", err)
}
if got.NotBefore == nil || !got.NotBefore.Equal(nb) {
t.Errorf("NotBefore mismatch: got %v", got.NotBefore)
}
if got.ExpiresAt == nil || !got.ExpiresAt.Equal(exp) {
t.Errorf("ExpiresAt mismatch: got %v", got.ExpiresAt)
}
}
func TestUpdatePolicyRule_SetExpiresAt(t *testing.T) {
db := openTestDB(t)
rec, _ := db.CreatePolicyRule("no expiry", 100, `{"effect":"allow"}`, nil, nil, nil)
exp := time.Date(2030, 12, 31, 23, 59, 59, 0, time.UTC)
expPtr := &exp
if err := db.UpdatePolicyRule(rec.ID, nil, nil, nil, nil, &expPtr); err != nil {
t.Fatalf("UpdatePolicyRule (set expires_at): %v", err)
}
got, _ := db.GetPolicyRule(rec.ID)
if got.ExpiresAt == nil {
t.Fatal("expected ExpiresAt to be set after update")
}
if !got.ExpiresAt.Equal(exp) {
t.Errorf("expected ExpiresAt=%v, got %v", exp, *got.ExpiresAt)
}
}
func TestUpdatePolicyRule_ClearExpiresAt(t *testing.T) {
db := openTestDB(t)
exp := time.Date(2030, 6, 1, 0, 0, 0, 0, time.UTC)
rec, _ := db.CreatePolicyRule("will clear", 100, `{"effect":"allow"}`, nil, nil, &exp)
// Clear expires_at by passing non-nil outer, nil inner.
var nilTime *time.Time
if err := db.UpdatePolicyRule(rec.ID, nil, nil, nil, nil, &nilTime); err != nil {
t.Fatalf("UpdatePolicyRule (clear expires_at): %v", err)
}
got, _ := db.GetPolicyRule(rec.ID)
if got.ExpiresAt != nil {
t.Errorf("expected ExpiresAt=nil after clear, got %v", *got.ExpiresAt)
}
}