Files
mcias/internal/db/policy.go
Kyle Isom 22158824bd Checkpoint: password reset, rule expiry, migrations
- Self-service and admin password-change endpoints
  (PUT /v1/auth/password, PUT /v1/accounts/{id}/password)
- Policy rule time-scoped expiry (not_before / expires_at)
  with migration 000006 and engine filtering
- golang-migrate integration; embedded SQL migrations
- PolicyRecord fieldalignment lint fix

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-12 14:38:38 -07:00

233 lines
6.8 KiB
Go

package db
import (
"database/sql"
"errors"
"fmt"
"time"
"git.wntrmute.dev/kyle/mcias/internal/model"
)
// policyRuleCols is the SELECT column list shared by every policy rule query.
// The order matters: scanPolicyRule and scanPolicyRuleRow scan their
// destinations positionally in exactly this order, so any change here must be
// mirrored there.
const policyRuleCols = "id, priority, description, rule_json, enabled, " +
	"created_by, created_at, updated_at, not_before, expires_at"
// CreatePolicyRule inserts a new policy rule record. The returned record
// includes the database-assigned ID and timestamps. The rule is always
// created enabled.
//
// notBefore and expiresAt are optional; nil means no constraint. They are
// persisted via formatNullableTime (UTC, RFC 3339, whole seconds), and the
// returned record carries matching normalized copies rather than the caller's
// pointers. The create response therefore agrees with the stored text, and
// later mutation of the caller's values cannot alter the returned record.
func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string, createdBy *int64, notBefore, expiresAt *time.Time) (*model.PolicyRuleRecord, error) {
	// storedTime applies the same normalization formatNullableTime performs
	// before storage: convert to UTC and drop fractional seconds (the RFC3339
	// layout has no sub-second field). It returns a fresh pointer.
	storedTime := func(t *time.Time) *time.Time {
		if t == nil {
			return nil
		}
		v := t.UTC().Truncate(time.Second)
		return &v
	}

	n := now()
	result, err := db.sql.Exec(`
INSERT INTO policy_rules (priority, description, rule_json, enabled, created_by, created_at, updated_at, not_before, expires_at)
VALUES (?, ?, ?, 1, ?, ?, ?, ?, ?)
`, priority, description, ruleJSON, createdBy, n, n, formatNullableTime(notBefore), formatNullableTime(expiresAt))
	if err != nil {
		return nil, fmt.Errorf("db: create policy rule: %w", err)
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("db: create policy rule last insert id: %w", err)
	}
	// created_at and updated_at were written from the same value n, so one
	// parse serves both fields.
	createdAt, err := parseTime(n)
	if err != nil {
		return nil, err
	}
	return &model.PolicyRuleRecord{
		ID:          id,
		Priority:    priority,
		Description: description,
		RuleJSON:    ruleJSON,
		Enabled:     true,
		CreatedBy:   createdBy,
		CreatedAt:   createdAt,
		UpdatedAt:   createdAt,
		NotBefore:   storedTime(notBefore),
		ExpiresAt:   storedTime(expiresAt),
	}, nil
}
// GetPolicyRule looks up one policy rule by its numeric database ID.
// A missing row is reported as ErrNotFound (translated by scanPolicyRule).
func (db *DB) GetPolicyRule(id int64) (*model.PolicyRuleRecord, error) {
	row := db.sql.QueryRow(`
SELECT `+policyRuleCols+`
FROM policy_rules WHERE id = ?
`, id)
	return db.scanPolicyRule(row)
}
// ListPolicyRules fetches policy rules sorted by ascending priority, with ID
// as the tie-breaker. If enabledOnly is set, rows whose enabled column is not
// 1 are excluded. When no rows match, the returned slice is nil.
func (db *DB) ListPolicyRules(enabledOnly bool) ([]*model.PolicyRuleRecord, error) {
	var filter string
	if enabledOnly {
		filter = ` WHERE enabled = 1`
	}
	query := `
SELECT ` + policyRuleCols + `
FROM policy_rules` + filter + ` ORDER BY priority ASC, id ASC`

	rows, err := db.sql.Query(query)
	if err != nil {
		return nil, fmt.Errorf("db: list policy rules: %w", err)
	}
	// The Close error is deliberately ignored; iteration errors are surfaced
	// through rows.Err below.
	defer func() { _ = rows.Close() }()

	var out []*model.PolicyRuleRecord
	for rows.Next() {
		rec, scanErr := db.scanPolicyRuleRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, rec)
	}
	return out, rows.Err()
}
// UpdatePolicyRule applies a partial update to the policy rule with the given
// ID. updated_at is always refreshed. Every other column is written only when
// its corresponding argument is non-nil.
//
// notBefore and expiresAt are pointer-to-pointer so that "leave unchanged"
// and "clear" can be told apart (expiresAt follows the same rules):
//   - notBefore == nil: column left unchanged
//   - notBefore != nil, *notBefore == nil: column set to NULL
//   - notBefore != nil, *notBefore != nil: column set to that time
//
// An id that matches no row is not reported as an error; only failures from
// Exec are returned.
func (db *DB) UpdatePolicyRule(id int64, description *string, priority *int, ruleJSON *string, notBefore, expiresAt **time.Time) error {
	setClauses := "updated_at = ?"
	args := []interface{}{now()}

	// Security: every column name passed to addColumn is a literal in this
	// function, never caller input. Only values travel as bound parameters.
	addColumn := func(column string, value interface{}) {
		setClauses += ", " + column + " = ?"
		args = append(args, value)
	}

	if description != nil {
		addColumn("description", *description)
	}
	if priority != nil {
		addColumn("priority", *priority)
	}
	if ruleJSON != nil {
		addColumn("rule_json", *ruleJSON)
	}
	if notBefore != nil {
		addColumn("not_before", formatNullableTime(*notBefore))
	}
	if expiresAt != nil {
		addColumn("expires_at", formatNullableTime(*expiresAt))
	}

	args = append(args, id)
	if _, err := db.sql.Exec(`UPDATE policy_rules SET `+setClauses+` WHERE id = ?`, args...); err != nil {
		return fmt.Errorf("db: update policy rule %d: %w", id, err)
	}
	return nil
}
// SetPolicyRuleEnabled toggles a policy rule's enabled flag (stored as 1 or 0)
// and refreshes its updated_at timestamp. An id that matches no row is not
// reported as an error.
func (db *DB) SetPolicyRuleEnabled(id int64, enabled bool) error {
	var flag int
	if enabled {
		flag = 1
	}
	if _, err := db.sql.Exec(`
UPDATE policy_rules SET enabled = ?, updated_at = ? WHERE id = ?
`, flag, now(), id); err != nil {
		return fmt.Errorf("db: set policy rule %d enabled=%v: %w", id, enabled, err)
	}
	return nil
}
// DeletePolicyRule deletes the policy rule with the given ID. Deleting an ID
// that does not exist is not treated as an error.
func (db *DB) DeletePolicyRule(id int64) error {
	if _, err := db.sql.Exec(`DELETE FROM policy_rules WHERE id = ?`, id); err != nil {
		return fmt.Errorf("db: delete policy rule %d: %w", id, err)
	}
	return nil
}
// scanPolicyRule decodes a single-row query result (columns in policyRuleCols
// order) into a PolicyRuleRecord. sql.ErrNoRows is mapped to ErrNotFound;
// any other Scan failure is wrapped with context.
func (db *DB) scanPolicyRule(row *sql.Row) (*model.PolicyRuleRecord, error) {
	var (
		rec                  model.PolicyRuleRecord
		enabled              int
		created, updated     string
		creator              *int64
		notBefore, expiresAt *string
	)
	err := row.Scan(
		&rec.ID, &rec.Priority, &rec.Description, &rec.RuleJSON,
		&enabled, &creator, &created, &updated,
		&notBefore, &expiresAt,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("db: scan policy rule: %w", err)
	}
	return finishPolicyRuleScan(&rec, enabled, creator, created, updated, notBefore, expiresAt)
}
// scanPolicyRuleRow decodes the row that rows is currently positioned on
// (columns in policyRuleCols order) into a PolicyRuleRecord. Unlike
// scanPolicyRule it has no ErrNotFound case: every Scan failure is wrapped
// and returned.
func (db *DB) scanPolicyRuleRow(rows *sql.Rows) (*model.PolicyRuleRecord, error) {
	var (
		rec                  model.PolicyRuleRecord
		enabled              int
		created, updated     string
		creator              *int64
		notBefore, expiresAt *string
	)
	if err := rows.Scan(
		&rec.ID, &rec.Priority, &rec.Description, &rec.RuleJSON,
		&enabled, &creator, &created, &updated,
		&notBefore, &expiresAt,
	); err != nil {
		return nil, fmt.Errorf("db: scan policy rule row: %w", err)
	}
	return finishPolicyRuleScan(&rec, enabled, creator, created, updated, notBefore, expiresAt)
}
// finishPolicyRuleScan completes a PolicyRuleRecord whose plain columns were
// already scanned into r. It converts the raw enabled flag (true only when
// exactly 1), attaches createdBy, and parses the timestamp strings with
// parseTime (required columns) and nullableTime (nullable columns).
// It returns r on success. On the first parse error it returns (nil, err),
// and r may then be partially populated.
func finishPolicyRuleScan(r *model.PolicyRuleRecord, enabledInt int, createdBy *int64, createdAtStr, updatedAtStr string, notBeforeStr, expiresAtStr *string) (*model.PolicyRuleRecord, error) {
	r.Enabled = enabledInt == 1
	r.CreatedBy = createdBy

	var err error
	if r.CreatedAt, err = parseTime(createdAtStr); err != nil {
		return nil, err
	}
	if r.UpdatedAt, err = parseTime(updatedAtStr); err != nil {
		return nil, err
	}
	if r.NotBefore, err = nullableTime(notBeforeStr); err != nil {
		return nil, err
	}
	if r.ExpiresAt, err = nullableTime(expiresAtStr); err != nil {
		return nil, err
	}
	return r, nil
}
// formatNullableTime renders an optional time as RFC 3339 text in UTC for
// storage in SQLite. A nil input yields a nil *string, which stores NULL.
// The RFC3339 layout has no fractional-second field, so sub-second precision
// is dropped (truncated, not rounded).
func formatNullableTime(t *time.Time) *string {
	if t != nil {
		formatted := t.In(time.UTC).Format(time.RFC3339)
		return &formatted
	}
	return nil
}