Move SSO clients from config to database
- Add sso_clients table (migration 000010) with client_id, redirect_uri, tags (JSON), enabled flag, and audit timestamps - Add SSOClient model struct and audit events - Implement DB CRUD with 10 unit tests - Add REST API: GET/POST/PATCH/DELETE /v1/sso/clients (policy-gated) - Add gRPC SSOClientService with 5 RPCs (admin-only) - Add mciasctl sso list/create/get/update/delete commands - Add web UI admin page at /sso-clients with HTMX create/toggle/delete - Migrate handleSSOAuthorize and handleSSOTokenExchange to use DB - Remove SSOConfig, SSOClient struct, lookup methods from config - Simplify: client_id = service_name for policy evaluation Security: - SSO client CRUD is admin-only (policy-gated REST, requireAdmin gRPC) - redirect_uri must use https:// (validated at DB layer) - Disabled clients are rejected at both authorize and token exchange - All mutations write audit events Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -22,7 +22,7 @@ var migrationsFS embed.FS
|
||||
// LatestSchemaVersion is the highest migration version defined in the
|
||||
// migrations/ directory. Update this constant whenever a new migration file
|
||||
// is added.
|
||||
const LatestSchemaVersion = 9
|
||||
const LatestSchemaVersion = 10
|
||||
|
||||
// newMigrate constructs a migrate.Migrate instance backed by the embedded SQL
|
||||
// files. It opens a dedicated *sql.DB using the same DSN as the main
|
||||
|
||||
10
internal/db/migrations/000010_sso_clients.up.sql
Normal file
10
internal/db/migrations/000010_sso_clients.up.sql
Normal file
@@ -0,0 +1,10 @@
|
||||
CREATE TABLE sso_clients (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
client_id TEXT NOT NULL UNIQUE,
|
||||
redirect_uri TEXT NOT NULL,
|
||||
tags_json TEXT NOT NULL DEFAULT '[]',
|
||||
enabled INTEGER NOT NULL DEFAULT 1 CHECK (enabled IN (0,1)),
|
||||
created_by INTEGER REFERENCES accounts(id),
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now'))
|
||||
);
|
||||
206
internal/db/sso_clients.go
Normal file
206
internal/db/sso_clients.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/mc/mcias/internal/model"
|
||||
)
|
||||
|
||||
const ssoClientCols = `id, client_id, redirect_uri, tags_json, enabled, created_by, created_at, updated_at`
|
||||
|
||||
// CreateSSOClient inserts a new SSO client. The client_id must be unique
|
||||
// and the redirect_uri must start with "https://".
|
||||
func (db *DB) CreateSSOClient(clientID, redirectURI string, tags []string, createdBy *int64) (*model.SSOClient, error) {
|
||||
if clientID == "" {
|
||||
return nil, fmt.Errorf("db: client_id is required")
|
||||
}
|
||||
if !strings.HasPrefix(redirectURI, "https://") {
|
||||
return nil, fmt.Errorf("db: redirect_uri must start with https://")
|
||||
}
|
||||
if tags == nil {
|
||||
tags = []string{}
|
||||
}
|
||||
|
||||
tagsJSON, err := json.Marshal(tags)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: marshal tags: %w", err)
|
||||
}
|
||||
|
||||
n := now()
|
||||
result, err := db.sql.Exec(`
|
||||
INSERT INTO sso_clients (client_id, redirect_uri, tags_json, enabled, created_by, created_at, updated_at)
|
||||
VALUES (?, ?, ?, 1, ?, ?, ?)
|
||||
`, clientID, redirectURI, string(tagsJSON), createdBy, n, n)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create SSO client: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: create SSO client last insert id: %w", err)
|
||||
}
|
||||
|
||||
createdAt, err := parseTime(n)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &model.SSOClient{
|
||||
ID: id,
|
||||
ClientID: clientID,
|
||||
RedirectURI: redirectURI,
|
||||
Tags: tags,
|
||||
Enabled: true,
|
||||
CreatedBy: createdBy,
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: createdAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetSSOClient retrieves an SSO client by client_id.
|
||||
// Returns ErrNotFound if no such client exists.
|
||||
func (db *DB) GetSSOClient(clientID string) (*model.SSOClient, error) {
|
||||
return scanSSOClient(db.sql.QueryRow(`
|
||||
SELECT `+ssoClientCols+`
|
||||
FROM sso_clients WHERE client_id = ?
|
||||
`, clientID))
|
||||
}
|
||||
|
||||
// ListSSOClients returns all SSO clients ordered by client_id.
|
||||
func (db *DB) ListSSOClients() ([]*model.SSOClient, error) {
|
||||
rows, err := db.sql.Query(`
|
||||
SELECT ` + ssoClientCols + `
|
||||
FROM sso_clients ORDER BY client_id ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: list SSO clients: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var clients []*model.SSOClient
|
||||
for rows.Next() {
|
||||
c, err := scanSSOClientRow(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clients = append(clients, c)
|
||||
}
|
||||
return clients, rows.Err()
|
||||
}
|
||||
|
||||
// UpdateSSOClient updates the mutable fields of an SSO client.
|
||||
// Only non-nil fields are changed.
|
||||
func (db *DB) UpdateSSOClient(clientID string, redirectURI *string, tags *[]string, enabled *bool) error {
|
||||
n := now()
|
||||
setClauses := "updated_at = ?"
|
||||
args := []interface{}{n}
|
||||
|
||||
if redirectURI != nil {
|
||||
if !strings.HasPrefix(*redirectURI, "https://") {
|
||||
return fmt.Errorf("db: redirect_uri must start with https://")
|
||||
}
|
||||
setClauses += ", redirect_uri = ?"
|
||||
args = append(args, *redirectURI)
|
||||
}
|
||||
if tags != nil {
|
||||
tagsJSON, err := json.Marshal(*tags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: marshal tags: %w", err)
|
||||
}
|
||||
setClauses += ", tags_json = ?"
|
||||
args = append(args, string(tagsJSON))
|
||||
}
|
||||
if enabled != nil {
|
||||
enabledInt := 0
|
||||
if *enabled {
|
||||
enabledInt = 1
|
||||
}
|
||||
setClauses += ", enabled = ?"
|
||||
args = append(args, enabledInt)
|
||||
}
|
||||
args = append(args, clientID)
|
||||
|
||||
res, err := db.sql.Exec(`UPDATE sso_clients SET `+setClauses+` WHERE client_id = ?`, args...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: update SSO client %s: %w", clientID, err)
|
||||
}
|
||||
n2, _ := res.RowsAffected()
|
||||
if n2 == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSSOClient removes an SSO client by client_id.
|
||||
func (db *DB) DeleteSSOClient(clientID string) error {
|
||||
res, err := db.sql.Exec(`DELETE FROM sso_clients WHERE client_id = ?`, clientID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("db: delete SSO client %s: %w", clientID, err)
|
||||
}
|
||||
n, _ := res.RowsAffected()
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// scanSSOClient scans a single SSO client from a *sql.Row.
|
||||
func scanSSOClient(row *sql.Row) (*model.SSOClient, error) {
|
||||
var c model.SSOClient
|
||||
var enabledInt int
|
||||
var tagsJSON, createdAtStr, updatedAtStr string
|
||||
var createdBy *int64
|
||||
|
||||
err := row.Scan(&c.ID, &c.ClientID, &c.RedirectURI, &tagsJSON,
|
||||
&enabledInt, &createdBy, &createdAtStr, &updatedAtStr)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan SSO client: %w", err)
|
||||
}
|
||||
|
||||
return finishSSOClientScan(&c, enabledInt, createdBy, tagsJSON, createdAtStr, updatedAtStr)
|
||||
}
|
||||
|
||||
// scanSSOClientRow scans a single SSO client from *sql.Rows.
|
||||
func scanSSOClientRow(rows *sql.Rows) (*model.SSOClient, error) {
|
||||
var c model.SSOClient
|
||||
var enabledInt int
|
||||
var tagsJSON, createdAtStr, updatedAtStr string
|
||||
var createdBy *int64
|
||||
|
||||
err := rows.Scan(&c.ID, &c.ClientID, &c.RedirectURI, &tagsJSON,
|
||||
&enabledInt, &createdBy, &createdAtStr, &updatedAtStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: scan SSO client row: %w", err)
|
||||
}
|
||||
|
||||
return finishSSOClientScan(&c, enabledInt, createdBy, tagsJSON, createdAtStr, updatedAtStr)
|
||||
}
|
||||
|
||||
func finishSSOClientScan(c *model.SSOClient, enabledInt int, createdBy *int64, tagsJSON, createdAtStr, updatedAtStr string) (*model.SSOClient, error) {
|
||||
c.Enabled = enabledInt == 1
|
||||
c.CreatedBy = createdBy
|
||||
|
||||
var err error
|
||||
if c.CreatedAt, err = parseTime(createdAtStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.UpdatedAt, err = parseTime(updatedAtStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil {
|
||||
return nil, fmt.Errorf("db: unmarshal SSO client tags: %w", err)
|
||||
}
|
||||
if c.Tags == nil {
|
||||
c.Tags = []string{}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
192
internal/db/sso_clients_test.go
Normal file
192
internal/db/sso_clients_test.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCreateAndGetSSOClient(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
c, err := db.CreateSSOClient("mcr", "https://mcr.example.com/sso/callback", []string{"env:prod"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("CreateSSOClient: %v", err)
|
||||
}
|
||||
if c.ID == 0 {
|
||||
t.Error("expected non-zero ID")
|
||||
}
|
||||
if c.ClientID != "mcr" {
|
||||
t.Errorf("client_id = %q, want %q", c.ClientID, "mcr")
|
||||
}
|
||||
if !c.Enabled {
|
||||
t.Error("new client should be enabled by default")
|
||||
}
|
||||
if len(c.Tags) != 1 || c.Tags[0] != "env:prod" {
|
||||
t.Errorf("tags = %v, want [env:prod]", c.Tags)
|
||||
}
|
||||
|
||||
got, err := db.GetSSOClient("mcr")
|
||||
if err != nil {
|
||||
t.Fatalf("GetSSOClient: %v", err)
|
||||
}
|
||||
if got.RedirectURI != "https://mcr.example.com/sso/callback" {
|
||||
t.Errorf("redirect_uri = %q", got.RedirectURI)
|
||||
}
|
||||
if len(got.Tags) != 1 || got.Tags[0] != "env:prod" {
|
||||
t.Errorf("tags = %v after round-trip", got.Tags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSSOClient_DuplicateClientID(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateSSOClient("mcr", "https://mcr.example.com/cb", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("first create: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.CreateSSOClient("mcr", "https://other.example.com/cb", nil, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for duplicate client_id")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSSOClient_Validation(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateSSOClient("", "https://example.com/cb", nil, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for empty client_id")
|
||||
}
|
||||
|
||||
_, err = db.CreateSSOClient("mcr", "http://example.com/cb", nil, nil)
|
||||
if err == nil {
|
||||
t.Error("expected error for non-https redirect_uri")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSSOClient_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.GetSSOClient("nonexistent")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListSSOClients(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
clients, err := db.ListSSOClients()
|
||||
if err != nil {
|
||||
t.Fatalf("ListSSOClients (empty): %v", err)
|
||||
}
|
||||
if len(clients) != 0 {
|
||||
t.Errorf("expected 0 clients, got %d", len(clients))
|
||||
}
|
||||
|
||||
_, _ = db.CreateSSOClient("mcat", "https://mcat.example.com/cb", nil, nil)
|
||||
_, _ = db.CreateSSOClient("mcr", "https://mcr.example.com/cb", nil, nil)
|
||||
|
||||
clients, err = db.ListSSOClients()
|
||||
if err != nil {
|
||||
t.Fatalf("ListSSOClients: %v", err)
|
||||
}
|
||||
if len(clients) != 2 {
|
||||
t.Fatalf("expected 2 clients, got %d", len(clients))
|
||||
}
|
||||
// Ordered by client_id ASC.
|
||||
if clients[0].ClientID != "mcat" {
|
||||
t.Errorf("first client = %q, want %q", clients[0].ClientID, "mcat")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSSOClient(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateSSOClient("mcr", "https://mcr.example.com/cb", []string{"a"}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
newURI := "https://mcr.example.com/sso/callback"
|
||||
newTags := []string{"b", "c"}
|
||||
disabled := false
|
||||
if err := db.UpdateSSOClient("mcr", &newURI, &newTags, &disabled); err != nil {
|
||||
t.Fatalf("UpdateSSOClient: %v", err)
|
||||
}
|
||||
|
||||
got, err := db.GetSSOClient("mcr")
|
||||
if err != nil {
|
||||
t.Fatalf("get after update: %v", err)
|
||||
}
|
||||
if got.RedirectURI != newURI {
|
||||
t.Errorf("redirect_uri = %q, want %q", got.RedirectURI, newURI)
|
||||
}
|
||||
if len(got.Tags) != 2 || got.Tags[0] != "b" {
|
||||
t.Errorf("tags = %v, want [b c]", got.Tags)
|
||||
}
|
||||
if got.Enabled {
|
||||
t.Error("expected enabled=false after update")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateSSOClient_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
uri := "https://x.example.com/cb"
|
||||
err := db.UpdateSSOClient("nonexistent", &uri, nil, nil)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSSOClient(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateSSOClient("mcr", "https://mcr.example.com/cb", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteSSOClient("mcr"); err != nil {
|
||||
t.Fatalf("DeleteSSOClient: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetSSOClient("mcr")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteSSOClient_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
err := db.DeleteSSOClient("nonexistent")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Errorf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateSSOClient_NilTags(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
c, err := db.CreateSSOClient("mcr", "https://mcr.example.com/cb", nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
if c.Tags == nil {
|
||||
t.Error("Tags should be empty slice, not nil")
|
||||
}
|
||||
if len(c.Tags) != 0 {
|
||||
t.Errorf("expected 0 tags, got %d", len(c.Tags))
|
||||
}
|
||||
|
||||
got, err := db.GetSSOClient("mcr")
|
||||
if err != nil {
|
||||
t.Fatalf("get: %v", err)
|
||||
}
|
||||
if got.Tags == nil || len(got.Tags) != 0 {
|
||||
t.Errorf("Tags round-trip: got %v", got.Tags)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user