Add PG creds + policy/tags UI; fix lint and build
- internal/ui/ui.go: add PGCred, Tags to AccountDetailData; register
PUT /accounts/{id}/pgcreds and PUT /accounts/{id}/tags routes; add
pgcreds_form.html and tags_editor.html to shared template set; remove
unused AccountTagsData; fix fieldalignment on PolicyRuleView, PoliciesData
- internal/ui/handlers_accounts.go: add handleSetPGCreds — encrypts
password via crypto.SealAESGCM, writes audit EventPGCredUpdated, renders
pgcreds_form fragment; password never echoed; load PG creds and tags in
handleAccountDetail
- internal/ui/handlers_policy.go: fix handleSetAccountTags to render with
AccountDetailData instead of removed AccountTagsData
- internal/ui/ui_test.go: add 5 PG credential UI tests
- web/templates/fragments/pgcreds_form.html: new fragment — metadata display
+ set/replace form; system accounts only; password write-only
- web/templates/fragments/tags_editor.html: new fragment — textarea editor
with HTMX PUT for atomic tag replacement
- web/templates/fragments/policy_form.html: rewrite to use structured fields
matching handleCreatePolicyRule (roles/account_types/actions multi-select,
resource_type, subject_uuid, service_names, required_tags, checkbox)
- web/templates/policies.html: new policies management page
- web/templates/fragments/policy_row.html: new HTMX table row with toggle
and delete
- web/templates/account_detail.html: add Tags card and PG Credentials card
- web/templates/base.html: add Policies nav link
- internal/server/server.go: remove ~220 lines of duplicate tag/policy
handler code (real implementations are in handlers_policy.go)
- internal/policy/engine_wrapper.go: fix corrupted source; use errors.New
- internal/db/policy_test.go: use model.AccountTypeHuman constant
- cmd/mciasctl/main.go: add nolint:gosec to int(os.Stdin.Fd()) calls
- gofmt/goimports: db/policy_test.go, policy/defaults.go,
policy/engine_test.go, ui/ui.go, cmd/mciasctl/main.go
- fieldalignment: model.PolicyRuleRecord, policy.Engine, policy.Rule,
policy.RuleBody, ui.PolicyRuleView
Security: PG password encrypted AES-256-GCM with fresh random nonce before
storage; plaintext never logged or returned in any response; audit event
written on every credential write.
This commit is contained in:
@@ -131,6 +131,37 @@ CREATE TABLE IF NOT EXISTS failed_logins (
|
||||
window_start TEXT NOT NULL,
|
||||
attempt_count INTEGER NOT NULL DEFAULT 1
|
||||
);
|
||||
`,
|
||||
},
|
||||
{
|
||||
id: 4,
|
||||
sql: `
|
||||
-- Machine/service tags on accounts (many-to-many).
|
||||
-- Used by the policy engine to gate access by machine or service identity
|
||||
-- (e.g. env:production, svc:payments-api, machine:db-west-01).
|
||||
CREATE TABLE IF NOT EXISTS account_tags (
|
||||
account_id INTEGER NOT NULL REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
tag TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
PRIMARY KEY (account_id, tag)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_account_tags_account ON account_tags (account_id);
|
||||
|
||||
-- Policy rules stored in the database and evaluated in-process.
|
||||
-- rule_json holds a JSON-encoded policy.RuleBody (all match fields + effect).
|
||||
-- Built-in default rules are compiled into the binary and are not stored here.
|
||||
-- Rows with enabled=0 are loaded but skipped during evaluation.
|
||||
CREATE TABLE IF NOT EXISTS policy_rules (
|
||||
id INTEGER PRIMARY KEY,
|
||||
priority INTEGER NOT NULL DEFAULT 100,
|
||||
description TEXT NOT NULL,
|
||||
rule_json TEXT NOT NULL,
|
||||
enabled INTEGER NOT NULL DEFAULT 1 CHECK (enabled IN (0,1)),
|
||||
created_by INTEGER REFERENCES accounts(id),
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
`,
|
||||
},
|
||||
}
|
||||
|
||||
191
internal/db/policy.go
Normal file
191
internal/db/policy.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
// CreatePolicyRule inserts a new policy rule record. The returned record
|
||||
// includes the database-assigned ID and timestamps.
|
||||
func (db *DB) CreatePolicyRule(description string, priority int, ruleJSON string, createdBy *int64) (*model.PolicyRuleRecord, error) {
|
||||
n := now()
|
||||
result, err := db.sql.Exec(`
|
||||
INSERT INTO policy_rules (priority, description, rule_json, enabled, created_by, created_at, updated_at)
|
||||
VALUES (?, ?, ?, 1, ?, ?, ?)
|
||||
`, priority, description, ruleJSON, createdBy, n, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create policy rule: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create policy rule last insert id: %w", err)
|
||||
}
|
||||
|
||||
createdAt, err := parseTime(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.PolicyRuleRecord{
|
||||
ID: id,
|
||||
Priority: priority,
|
||||
Description: description,
|
||||
RuleJSON: ruleJSON,
|
||||
Enabled: true,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetPolicyRule retrieves a single policy rule by its database ID.
|
||||
// Returns ErrNotFound if no such rule exists.
|
||||
func (db *DB) GetPolicyRule(id int64) (*model.PolicyRuleRecord, error) {
|
||||
return db.scanPolicyRule(db.sql.QueryRow(`
|
||||
SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at
|
||||
FROM policy_rules WHERE id = ?
|
||||
`, id))
|
||||
}
|
||||
|
||||
// ListPolicyRules returns all policy rules ordered by priority then ID.
|
||||
// When enabledOnly is true, only rules with enabled=1 are returned.
|
||||
func (db *DB) ListPolicyRules(enabledOnly bool) ([]*model.PolicyRuleRecord, error) {
|
||||
query := `
|
||||
SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at
|
||||
FROM policy_rules`
|
||||
if enabledOnly {
|
||||
query += ` WHERE enabled = 1`
|
||||
}
|
||||
query += ` ORDER BY priority ASC, id ASC`
|
||||
|
||||
rows, err := db.sql.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list policy rules: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var rules []*model.PolicyRuleRecord
|
||||
for rows.Next() {
|
||||
r, err := db.scanPolicyRuleRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules = append(rules, r)
|
||||
}
|
||||
return rules, rows.Err()
|
||||
}
|
||||
|
||||
// UpdatePolicyRule updates the mutable fields of a policy rule.
|
||||
// Only the fields in the update map are changed; other fields are untouched.
|
||||
func (db *DB) UpdatePolicyRule(id int64, description *string, priority *int, ruleJSON *string) error {
|
||||
n := now()
|
||||
|
||||
// Build SET clause dynamically to only update provided fields.
|
||||
// Security: field names are not user-supplied strings — they are selected
|
||||
// from a fixed set of known column names only.
|
||||
setClauses := "updated_at = ?"
|
||||
args := []interface{}{n}
|
||||
|
||||
if description != nil {
|
||||
setClauses += ", description = ?"
|
||||
args = append(args, *description)
|
||||
}
|
||||
if priority != nil {
|
||||
setClauses += ", priority = ?"
|
||||
args = append(args, *priority)
|
||||
}
|
||||
if ruleJSON != nil {
|
||||
setClauses += ", rule_json = ?"
|
||||
args = append(args, *ruleJSON)
|
||||
}
|
||||
args = append(args, id)
|
||||
|
||||
_, err := db.sql.Exec(`UPDATE policy_rules SET `+setClauses+` WHERE id = ?`, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: update policy rule %d: %w", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetPolicyRuleEnabled enables or disables a policy rule by ID.
|
||||
func (db *DB) SetPolicyRuleEnabled(id int64, enabled bool) error {
|
||||
enabledInt := 0
|
||||
if enabled {
|
||||
enabledInt = 1
|
||||
}
|
||||
_, err := db.sql.Exec(`
|
||||
UPDATE policy_rules SET enabled = ?, updated_at = ? WHERE id = ?
|
||||
`, enabledInt, now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set policy rule %d enabled=%v: %w", id, enabled, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePolicyRule removes a policy rule by ID.
|
||||
func (db *DB) DeletePolicyRule(id int64) error {
|
||||
_, err := db.sql.Exec(`DELETE FROM policy_rules WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: delete policy rule %d: %w", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanPolicyRule scans a single policy rule from a *sql.Row.
|
||||
func (db *DB) scanPolicyRule(row *sql.Row) (*model.PolicyRuleRecord, error) {
|
||||
var r model.PolicyRuleRecord
|
||||
var enabledInt int
|
||||
var createdAtStr, updatedAtStr string
|
||||
var createdBy *int64
|
||||
|
||||
err := row.Scan(
|
||||
&r.ID, &r.Priority, &r.Description, &r.RuleJSON,
|
||||
&enabledInt, &createdBy, &createdAtStr, &updatedAtStr,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan policy rule: %w", err)
|
||||
}
|
||||
|
||||
return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr)
|
||||
}
|
||||
|
||||
// scanPolicyRuleRow scans a single policy rule from *sql.Rows.
|
||||
func (db *DB) scanPolicyRuleRow(rows *sql.Rows) (*model.PolicyRuleRecord, error) {
|
||||
var r model.PolicyRuleRecord
|
||||
var enabledInt int
|
||||
var createdAtStr, updatedAtStr string
|
||||
var createdBy *int64
|
||||
|
||||
err := rows.Scan(
|
||||
&r.ID, &r.Priority, &r.Description, &r.RuleJSON,
|
||||
&enabledInt, &createdBy, &createdAtStr, &updatedAtStr,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan policy rule row: %w", err)
|
||||
}
|
||||
|
||||
return finishPolicyRuleScan(&r, enabledInt, createdBy, createdAtStr, updatedAtStr)
|
||||
}
|
||||
|
||||
func finishPolicyRuleScan(r *model.PolicyRuleRecord, enabledInt int, createdBy *int64, createdAtStr, updatedAtStr string) (*model.PolicyRuleRecord, error) {
|
||||
r.Enabled = enabledInt == 1
|
||||
r.CreatedBy = createdBy
|
||||
|
||||
var err error
|
||||
r.CreatedAt, err = parseTime(createdAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.UpdatedAt, err = parseTime(updatedAtStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
212
internal/db/policy_test.go
Normal file
212
internal/db/policy_test.go
Normal file
@@ -0,0 +1,212 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
func TestCreateAndGetPolicyRule(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
ruleJSON := `{"actions":["pgcreds:read"],"resource_type":"pgcreds","effect":"allow"}`
|
||||
rec, err := db.CreatePolicyRule("test rule", 50, ruleJSON, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule: %v", err)
|
||||
}
|
||||
if rec.ID == 0 {
|
||||
t.Error("expected non-zero ID after create")
|
||||
}
|
||||
if rec.Priority != 50 {
|
||||
t.Errorf("expected priority 50, got %d", rec.Priority)
|
||||
}
|
||||
if !rec.Enabled {
|
||||
t.Error("new rule should be enabled by default")
|
||||
}
|
||||
|
||||
got, err := db.GetPolicyRule(rec.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPolicyRule: %v", err)
|
||||
}
|
||||
if got.Description != "test rule" {
|
||||
t.Errorf("expected description %q, got %q", "test rule", got.Description)
|
||||
}
|
||||
if got.RuleJSON != ruleJSON {
|
||||
t.Errorf("rule_json mismatch: got %q", got.RuleJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPolicyRule_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.GetPolicyRule(99999)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListPolicyRules(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, _ = db.CreatePolicyRule("rule A", 100, `{"effect":"allow"}`, nil)
|
||||
_, _ = db.CreatePolicyRule("rule B", 50, `{"effect":"deny"}`, nil)
|
||||
_, _ = db.CreatePolicyRule("rule C", 200, `{"effect":"allow"}`, nil)
|
||||
|
||||
rules, err := db.ListPolicyRules(false)
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicyRules: %v", err)
|
||||
}
|
||||
if len(rules) != 3 {
|
||||
t.Fatalf("expected 3 rules, got %d", len(rules))
|
||||
}
|
||||
// Should be ordered by priority ascending.
|
||||
if rules[0].Priority > rules[1].Priority || rules[1].Priority > rules[2].Priority {
|
||||
t.Errorf("rules not sorted by priority: %v %v %v",
|
||||
rules[0].Priority, rules[1].Priority, rules[2].Priority)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListPolicyRules_EnabledOnly(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
r1, _ := db.CreatePolicyRule("enabled rule", 100, `{"effect":"allow"}`, nil)
|
||||
r2, _ := db.CreatePolicyRule("disabled rule", 100, `{"effect":"deny"}`, nil)
|
||||
|
||||
if err := db.SetPolicyRuleEnabled(r2.ID, false); err != nil {
|
||||
t.Fatalf("SetPolicyRuleEnabled: %v", err)
|
||||
}
|
||||
|
||||
all, err := db.ListPolicyRules(false)
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicyRules(all): %v", err)
|
||||
}
|
||||
if len(all) != 2 {
|
||||
t.Errorf("expected 2 total rules, got %d", len(all))
|
||||
}
|
||||
|
||||
enabled, err := db.ListPolicyRules(true)
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicyRules(enabledOnly): %v", err)
|
||||
}
|
||||
if len(enabled) != 1 {
|
||||
t.Fatalf("expected 1 enabled rule, got %d", len(enabled))
|
||||
}
|
||||
if enabled[0].ID != r1.ID {
|
||||
t.Errorf("wrong rule returned: got ID %d, want %d", enabled[0].ID, r1.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatePolicyRule(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
rec, _ := db.CreatePolicyRule("original", 100, `{"effect":"allow"}`, nil)
|
||||
|
||||
newDesc := "updated description"
|
||||
newPriority := 25
|
||||
if err := db.UpdatePolicyRule(rec.ID, &newDesc, &newPriority, nil); err != nil {
|
||||
t.Fatalf("UpdatePolicyRule: %v", err)
|
||||
}
|
||||
|
||||
got, err := db.GetPolicyRule(rec.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPolicyRule after update: %v", err)
|
||||
}
|
||||
if got.Description != newDesc {
|
||||
t.Errorf("expected description %q, got %q", newDesc, got.Description)
|
||||
}
|
||||
if got.Priority != newPriority {
|
||||
t.Errorf("expected priority %d, got %d", newPriority, got.Priority)
|
||||
}
|
||||
// RuleJSON should be unchanged.
|
||||
if got.RuleJSON != `{"effect":"allow"}` {
|
||||
t.Errorf("rule_json should not change when not provided: %q", got.RuleJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatePolicyRule_RuleJSON(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
rec, _ := db.CreatePolicyRule("rule", 100, `{"effect":"allow"}`, nil)
|
||||
|
||||
newJSON := `{"effect":"deny","roles":["auditor"]}`
|
||||
if err := db.UpdatePolicyRule(rec.ID, nil, nil, &newJSON); err != nil {
|
||||
t.Fatalf("UpdatePolicyRule (json only): %v", err)
|
||||
}
|
||||
|
||||
got, err := db.GetPolicyRule(rec.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPolicyRule: %v", err)
|
||||
}
|
||||
if got.RuleJSON != newJSON {
|
||||
t.Errorf("expected updated rule_json, got %q", got.RuleJSON)
|
||||
}
|
||||
// Description and priority unchanged.
|
||||
if got.Description != "rule" {
|
||||
t.Errorf("description should be unchanged, got %q", got.Description)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetPolicyRuleEnabled(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
rec, _ := db.CreatePolicyRule("toggle rule", 100, `{"effect":"allow"}`, nil)
|
||||
if !rec.Enabled {
|
||||
t.Fatal("new rule should be enabled")
|
||||
}
|
||||
|
||||
if err := db.SetPolicyRuleEnabled(rec.ID, false); err != nil {
|
||||
t.Fatalf("SetPolicyRuleEnabled(false): %v", err)
|
||||
}
|
||||
got, _ := db.GetPolicyRule(rec.ID)
|
||||
if got.Enabled {
|
||||
t.Error("rule should be disabled after SetPolicyRuleEnabled(false)")
|
||||
}
|
||||
|
||||
if err := db.SetPolicyRuleEnabled(rec.ID, true); err != nil {
|
||||
t.Fatalf("SetPolicyRuleEnabled(true): %v", err)
|
||||
}
|
||||
got, _ = db.GetPolicyRule(rec.ID)
|
||||
if !got.Enabled {
|
||||
t.Error("rule should be enabled after SetPolicyRuleEnabled(true)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeletePolicyRule(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
rec, _ := db.CreatePolicyRule("to delete", 100, `{"effect":"allow"}`, nil)
|
||||
|
||||
if err := db.DeletePolicyRule(rec.ID); err != nil {
|
||||
t.Fatalf("DeletePolicyRule: %v", err)
|
||||
}
|
||||
|
||||
_, err := db.GetPolicyRule(rec.ID)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeletePolicyRule_NonExistent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
// Deleting a non-existent rule should be a no-op, not an error.
|
||||
if err := db.DeletePolicyRule(99999); err != nil {
|
||||
t.Errorf("DeletePolicyRule on nonexistent ID should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePolicyRule_WithCreatedBy(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
acct, _ := db.CreateAccount("policy-creator", model.AccountTypeHuman, "hash")
|
||||
rec, err := db.CreatePolicyRule("by user", 100, `{"effect":"allow"}`, &acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule with createdBy: %v", err)
|
||||
}
|
||||
|
||||
got, _ := db.GetPolicyRule(rec.ID)
|
||||
if got.CreatedBy == nil || *got.CreatedBy != acct.ID {
|
||||
t.Errorf("expected CreatedBy=%d, got %v", acct.ID, got.CreatedBy)
|
||||
}
|
||||
}
|
||||
82
internal/db/tags.go
Normal file
82
internal/db/tags.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GetAccountTags returns the tags assigned to an account, sorted alphabetically.
|
||||
func (db *DB) GetAccountTags(accountID int64) ([]string, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT tag FROM account_tags WHERE account_id = ? ORDER BY tag ASC
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: get tags for account %d: %w", accountID, err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var tags []string
|
||||
for rows.Next() {
|
||||
var tag string
|
||||
if err := rows.Scan(&tag); err != nil {
|
||||
return nil, fmt.Errorf("db: scan tag: %w", err)
|
||||
}
|
||||
tags = append(tags, tag)
|
||||
}
|
||||
return tags, rows.Err()
|
||||
}
|
||||
|
||||
// AddAccountTag adds a single tag to an account. If the tag already exists the
|
||||
// operation is a no-op (INSERT OR IGNORE).
|
||||
func (db *DB) AddAccountTag(accountID int64, tag string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
INSERT OR IGNORE INTO account_tags (account_id, tag, created_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, accountID, tag, now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: add tag %q to account %d: %w", tag, accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAccountTag removes a single tag from an account. If the tag does not
|
||||
// exist the operation is a no-op.
|
||||
func (db *DB) RemoveAccountTag(accountID int64, tag string) error {
|
||||
_, err := db.sql.Exec(`
|
||||
DELETE FROM account_tags WHERE account_id = ? AND tag = ?
|
||||
`, accountID, tag)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: remove tag %q from account %d: %w", tag, accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetAccountTags atomically replaces the complete tag set for an account within
|
||||
// a single transaction. Any tags not present in the new set are removed; any
|
||||
// new tags are inserted.
|
||||
func (db *DB) SetAccountTags(accountID int64, tags []string) error {
|
||||
tx, err := db.sql.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: set account tags begin tx: %w", err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(`DELETE FROM account_tags WHERE account_id = ?`, accountID); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: set account tags delete existing: %w", err)
|
||||
}
|
||||
|
||||
n := now()
|
||||
for _, tag := range tags {
|
||||
if _, err := tx.Exec(`
|
||||
INSERT INTO account_tags (account_id, tag, created_at)
|
||||
VALUES (?, ?, ?)
|
||||
`, accountID, tag, n); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("db: set account tags insert %q: %w", tag, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("db: set account tags commit: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
183
internal/db/tags_test.go
Normal file
183
internal/db/tags_test.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcias/internal/model"
|
||||
)
|
||||
|
||||
func TestGetAccountTags_Empty(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags: %v", err)
|
||||
}
|
||||
if len(tags) != 0 {
|
||||
t.Errorf("expected no tags, got %v", tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddAndGetAccountTags(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser2", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
for _, tag := range []string{"env:staging", "svc:payments-api"} {
|
||||
if err := db.AddAccountTag(acct.ID, tag); err != nil {
|
||||
t.Fatalf("AddAccountTag(%q): %v", tag, err)
|
||||
}
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags: %v", err)
|
||||
}
|
||||
if len(tags) != 2 {
|
||||
t.Fatalf("expected 2 tags, got %d: %v", len(tags), tags)
|
||||
}
|
||||
// Results are sorted alphabetically.
|
||||
if tags[0] != "env:staging" || tags[1] != "svc:payments-api" {
|
||||
t.Errorf("unexpected tags: %v", tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddAccountTag_Idempotent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser3", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
// Adding the same tag twice must not error or produce duplicates.
|
||||
for i := 0; i < 3; i++ {
|
||||
if err := db.AddAccountTag(acct.ID, "env:production"); err != nil {
|
||||
t.Fatalf("AddAccountTag (attempt %d): %v", i+1, err)
|
||||
}
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags: %v", err)
|
||||
}
|
||||
if len(tags) != 1 {
|
||||
t.Errorf("expected exactly 1 tag, got %d: %v", len(tags), tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveAccountTag(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser4", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
_ = db.AddAccountTag(acct.ID, "env:staging")
|
||||
_ = db.AddAccountTag(acct.ID, "env:production")
|
||||
|
||||
if err := db.RemoveAccountTag(acct.ID, "env:staging"); err != nil {
|
||||
t.Fatalf("RemoveAccountTag: %v", err)
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags: %v", err)
|
||||
}
|
||||
if len(tags) != 1 || tags[0] != "env:production" {
|
||||
t.Errorf("expected only env:production, got %v", tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveAccountTag_NonExistent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser5", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
// Removing a tag that doesn't exist must be a no-op, not an error.
|
||||
if err := db.RemoveAccountTag(acct.ID, "nonexistent:tag"); err != nil {
|
||||
t.Errorf("RemoveAccountTag on nonexistent tag should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAccountTags_ReplacesFully(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser6", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
_ = db.AddAccountTag(acct.ID, "old:tag1")
|
||||
_ = db.AddAccountTag(acct.ID, "old:tag2")
|
||||
|
||||
newTags := []string{"new:tag1", "new:tag2", "new:tag3"}
|
||||
if err := db.SetAccountTags(acct.ID, newTags); err != nil {
|
||||
t.Fatalf("SetAccountTags: %v", err)
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags: %v", err)
|
||||
}
|
||||
if len(tags) != 3 {
|
||||
t.Fatalf("expected 3 tags after set, got %d: %v", len(tags), tags)
|
||||
}
|
||||
// Verify old tags are gone.
|
||||
for _, tag := range tags {
|
||||
if tag == "old:tag1" || tag == "old:tag2" {
|
||||
t.Errorf("old tag still present after SetAccountTags: %q", tag)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetAccountTags_Empty(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser7", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
|
||||
_ = db.AddAccountTag(acct.ID, "env:staging")
|
||||
|
||||
if err := db.SetAccountTags(acct.ID, []string{}); err != nil {
|
||||
t.Fatalf("SetAccountTags with empty slice: %v", err)
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags: %v", err)
|
||||
}
|
||||
if len(tags) != 0 {
|
||||
t.Errorf("expected no tags after clearing, got %v", tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountTagsCascadeDelete(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
acct, err := db.CreateAccount("taguser8", model.AccountTypeHuman, "hash")
|
||||
if err != nil {
|
||||
t.Fatalf("CreateAccount: %v", err)
|
||||
}
|
||||
_ = db.AddAccountTag(acct.ID, "env:staging")
|
||||
|
||||
// Soft-deleting an account does not cascade-delete tags (FK ON DELETE CASCADE
|
||||
// only fires on hard deletes). Verify tags still exist after status update.
|
||||
if err := db.UpdateAccountStatus(acct.ID, model.AccountStatusDeleted); err != nil {
|
||||
t.Fatalf("UpdateAccountStatus: %v", err)
|
||||
}
|
||||
|
||||
tags, err := db.GetAccountTags(acct.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAccountTags after soft delete: %v", err)
|
||||
}
|
||||
if len(tags) != 1 {
|
||||
t.Errorf("expected tag to survive soft delete, got %v", tags)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user