package db

import (
	"database/sql"
	"errors"
	"fmt"
	"time"

	"git.wntrmute.dev/kyle/mcias/internal/model"
)

// policyRuleCols is the column list for all policy rule SELECT queries.
// Its order must match the Scan destinations in scanPolicyRule and
// scanPolicyRuleRow.
const policyRuleCols = `id, priority, description, rule_json, enabled, created_by, created_at, updated_at, not_before, expires_at`

// CreatePolicyRule inserts a new policy rule record. New rules are always
// created enabled. The returned record includes the database-assigned ID and
// timestamps.
//
// notBefore and expiresAt are optional; nil means no constraint. They are
// stored as UTC RFC 3339 strings (second precision), and the returned record
// carries those stored values rather than the caller's pointers, so it
// matches what GetPolicyRule would later return for the same row.
func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string, createdBy *int64, notBefore, expiresAt *time.Time) (*model.PolicyRuleRecord, error) {
	n := now()

	// Do every conversion before touching the database so that a conversion
	// failure cannot leave an inserted row behind while the caller is told
	// the create failed.
	createdAt, err := parseTime(n)
	if err != nil {
		return nil, err
	}
	notBeforeStr := formatNullableTime(notBefore)
	expiresAtStr := formatNullableTime(expiresAt)
	storedNotBefore, err := nullableTime(notBeforeStr)
	if err != nil {
		return nil, err
	}
	storedExpiresAt, err := nullableTime(expiresAtStr)
	if err != nil {
		return nil, err
	}

	result, err := db.sql.Exec(`
		INSERT INTO policy_rules
		    (priority, description, rule_json, enabled, created_by, created_at, updated_at, not_before, expires_at)
		VALUES (?, ?, ?, 1, ?, ?, ?, ?, ?)
	`, priority, description, ruleJSON, createdBy, n, n, notBeforeStr, expiresAtStr)
	if err != nil {
		return nil, fmt.Errorf("db: create policy rule: %w", err)
	}

	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("db: create policy rule last insert id: %w", err)
	}

	return &model.PolicyRuleRecord{
		ID:          id,
		Priority:    priority,
		Description: description,
		RuleJSON:    ruleJSON,
		Enabled:     true,
		CreatedBy:   createdBy,
		CreatedAt:   createdAt,
		UpdatedAt:   createdAt,
		NotBefore:   storedNotBefore,
		ExpiresAt:   storedExpiresAt,
	}, nil
}

// GetPolicyRule retrieves a single policy rule by its database ID.
// Returns ErrNotFound if no such rule exists.
func (db *DB) GetPolicyRule(id int64) (*model.PolicyRuleRecord, error) {
	return db.scanPolicyRule(db.sql.QueryRow(`
		SELECT `+policyRuleCols+`
		FROM policy_rules WHERE id = ?
	`, id))
}

// ListPolicyRules returns all policy rules ordered by priority then ID.
// When enabledOnly is true, only rules with enabled=1 are returned.
func (db *DB) ListPolicyRules(enabledOnly bool) ([]*model.PolicyRuleRecord, error) { query := ` SELECT ` + policyRuleCols + ` FROM policy_rules` if enabledOnly { query += ` WHERE enabled = 1` } query += ` ORDER BY priority ASC, id ASC` rows, err := db.sql.Query(query) if err != nil { return nil, fmt.Errorf("db: list policy rules: %w", err) } defer func() { _ = rows.Close() }() var rules []*model.PolicyRuleRecord for rows.Next() { r, err := db.scanPolicyRuleRow(rows) if err != nil { return nil, err } rules = append(rules, r) } return rules, rows.Err() } // UpdatePolicyRule updates the mutable fields of a policy rule. // Only non-nil fields are changed; nil fields are left untouched. // For notBefore and expiresAt, use a non-nil pointer-to-pointer: // - nil (outer) → don't change // - non-nil → nil → set column to NULL // - non-nil → non-nil → set column to the time value func (db *DB) UpdatePolicyRule(id int64, description *string, priority *int, ruleJSON *string, notBefore, expiresAt **time.Time) error { n := now() // Build SET clause dynamically to only update provided fields. // Security: field names are not user-supplied strings — they are selected // from a fixed set of known column names only. setClauses := "updated_at = ?" args := []interface{}{n} if description != nil { setClauses += ", description = ?" args = append(args, *description) } if priority != nil { setClauses += ", priority = ?" args = append(args, *priority) } if ruleJSON != nil { setClauses += ", rule_json = ?" args = append(args, *ruleJSON) } if notBefore != nil { setClauses += ", not_before = ?" args = append(args, formatNullableTime(*notBefore)) } if expiresAt != nil { setClauses += ", expires_at = ?" args = append(args, formatNullableTime(*expiresAt)) } args = append(args, id) _, err := db.sql.Exec(`UPDATE policy_rules SET `+setClauses+` WHERE id = ?`, args...) 
if err != nil { return fmt.Errorf("db: update policy rule %d: %w", id, err) } return nil } // SetPolicyRuleEnabled enables or disables a policy rule by ID. func (db *DB) SetPolicyRuleEnabled(id int64, enabled bool) error { enabledInt := 0 if enabled { enabledInt = 1 } _, err := db.sql.Exec(` UPDATE policy_rules SET enabled = ?, updated_at = ? WHERE id = ? `, enabledInt, now(), id) if err != nil { return fmt.Errorf("db: set policy rule %d enabled=%v: %w", id, enabled, err) } return nil } // DeletePolicyRule removes a policy rule by ID. func (db *DB) DeletePolicyRule(id int64) error { _, err := db.sql.Exec(`DELETE FROM policy_rules WHERE id = ?`, id) if err != nil { return fmt.Errorf("db: delete policy rule %d: %w", id, err) } return nil } // scanPolicyRule scans a single policy rule from a *sql.Row. func (db *DB) scanPolicyRule(row *sql.Row) (*model.PolicyRuleRecord, error) { var r model.PolicyRuleRecord var enabledInt int var createdAtStr, updatedAtStr string var createdBy *int64 var notBeforeStr, expiresAtStr *string err := row.Scan( &r.ID, &r.Priority, &r.Description, &r.RuleJSON, &enabledInt, &createdBy, &createdAtStr, &updatedAtStr, ¬BeforeStr, &expiresAtStr, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: scan policy rule: %w", err) } return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr, notBeforeStr, expiresAtStr) } // scanPolicyRuleRow scans a single policy rule from *sql.Rows. 
func (db *DB) scanPolicyRuleRow(rows *sql.Rows) (*model.PolicyRuleRecord, error) { var r model.PolicyRuleRecord var enabledInt int var createdAtStr, updatedAtStr string var createdBy *int64 var notBeforeStr, expiresAtStr *string err := rows.Scan( &r.ID, &r.Priority, &r.Description, &r.RuleJSON, &enabledInt, &createdBy, &createdAtStr, &updatedAtStr, ¬BeforeStr, &expiresAtStr, ) if err != nil { return nil, fmt.Errorf("db: scan policy rule row: %w", err) } return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr, notBeforeStr, expiresAtStr) } func finishPolicyRuleScan(r *model.PolicyRuleRecord, enabledInt int, createdBy *int64, createdAtStr, updatedAtStr string, notBeforeStr, expiresAtStr *string) (*model.PolicyRuleRecord, error) { r.Enabled = enabledInt == 1 r.CreatedBy = createdBy var err error r.CreatedAt, err = parseTime(createdAtStr) if err != nil { return nil, err } r.UpdatedAt, err = parseTime(updatedAtStr) if err != nil { return nil, err } r.NotBefore, err = nullableTime(notBeforeStr) if err != nil { return nil, err } r.ExpiresAt, err = nullableTime(expiresAtStr) if err != nil { return nil, err } return r, nil } // formatNullableTime converts a *time.Time to a *string suitable for SQLite. // Returns nil if the input is nil (stores NULL). func formatNullableTime(t *time.Time) *string { if t == nil { return nil } s := t.UTC().Format(time.RFC3339) return &s }