package db import ( "database/sql" "encoding/json" "errors" "fmt" "strings" "git.wntrmute.dev/mc/mcias/internal/model" ) const ssoClientCols = `id, client_id, redirect_uri, tags_json, enabled, created_by, created_at, updated_at` // CreateSSOClient inserts a new SSO client. The client_id must be unique // and the redirect_uri must start with "https://". func (db *DB) CreateSSOClient(clientID, redirectURI string, tags []string, createdBy *int64) (*model.SSOClient, error) { if clientID == "" { return nil, fmt.Errorf("db: client_id is required") } if !strings.HasPrefix(redirectURI, "https://") { return nil, fmt.Errorf("db: redirect_uri must start with https://") } if tags == nil { tags = []string{} } tagsJSON, err := json.Marshal(tags) if err != nil { return nil, fmt.Errorf("db: marshal tags: %w", err) } n := now() result, err := db.sql.Exec(` INSERT INTO sso_clients (client_id, redirect_uri, tags_json, enabled, created_by, created_at, updated_at) VALUES (?, ?, ?, 1, ?, ?, ?) `, clientID, redirectURI, string(tagsJSON), createdBy, n, n) if err != nil { return nil, fmt.Errorf("db: create SSO client: %w", err) } id, err := result.LastInsertId() if err != nil { return nil, fmt.Errorf("db: create SSO client last insert id: %w", err) } createdAt, err := parseTime(n) if err != nil { return nil, err } return &model.SSOClient{ ID: id, ClientID: clientID, RedirectURI: redirectURI, Tags: tags, Enabled: true, CreatedBy: createdBy, CreatedAt: createdAt, UpdatedAt: createdAt, }, nil } // GetSSOClient retrieves an SSO client by client_id. // Returns ErrNotFound if no such client exists. func (db *DB) GetSSOClient(clientID string) (*model.SSOClient, error) { return scanSSOClient(db.sql.QueryRow(` SELECT `+ssoClientCols+` FROM sso_clients WHERE client_id = ? `, clientID)) } // ListSSOClients returns all SSO clients ordered by client_id. func (db *DB) ListSSOClients() ([]*model.SSOClient, error) { rows, err := db.sql.Query(` SELECT ` + ssoClientCols + ` FROM sso_clients ORDER BY client_id ASC `) if err != nil { return nil, fmt.Errorf("db: list SSO clients: %w", err) } defer func() { _ = rows.Close() }() var clients []*model.SSOClient for rows.Next() { c, err := scanSSOClientRow(rows) if err != nil { return nil, err } clients = append(clients, c) } return clients, rows.Err() } // UpdateSSOClient updates the mutable fields of an SSO client. // Only non-nil fields are changed. func (db *DB) UpdateSSOClient(clientID string, redirectURI *string, tags *[]string, enabled *bool) error { n := now() setClauses := "updated_at = ?" args := []interface{}{n} if redirectURI != nil { if !strings.HasPrefix(*redirectURI, "https://") { return fmt.Errorf("db: redirect_uri must start with https://") } setClauses += ", redirect_uri = ?" args = append(args, *redirectURI) } if tags != nil { tagsJSON, err := json.Marshal(*tags) if err != nil { return fmt.Errorf("db: marshal tags: %w", err) } setClauses += ", tags_json = ?" args = append(args, string(tagsJSON)) } if enabled != nil { enabledInt := 0 if *enabled { enabledInt = 1 } setClauses += ", enabled = ?" args = append(args, enabledInt) } args = append(args, clientID) res, err := db.sql.Exec(`UPDATE sso_clients SET `+setClauses+` WHERE client_id = ?`, args...) if err != nil { return fmt.Errorf("db: update SSO client %s: %w", clientID, err) } n2, _ := res.RowsAffected() if n2 == 0 { return ErrNotFound } return nil } // DeleteSSOClient removes an SSO client by client_id. func (db *DB) DeleteSSOClient(clientID string) error { res, err := db.sql.Exec(`DELETE FROM sso_clients WHERE client_id = ?`, clientID) if err != nil { return fmt.Errorf("db: delete SSO client %s: %w", clientID, err) } n, _ := res.RowsAffected() if n == 0 { return ErrNotFound } return nil } // scanSSOClient scans a single SSO client from a *sql.Row. func scanSSOClient(row *sql.Row) (*model.SSOClient, error) { var c model.SSOClient var enabledInt int var tagsJSON, createdAtStr, updatedAtStr string var createdBy *int64 err := row.Scan(&c.ID, &c.ClientID, &c.RedirectURI, &tagsJSON, &enabledInt, &createdBy, &createdAtStr, &updatedAtStr) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: scan SSO client: %w", err) } return finishSSOClientScan(&c, enabledInt, createdBy, tagsJSON, createdAtStr, updatedAtStr) } // scanSSOClientRow scans a single SSO client from *sql.Rows. func scanSSOClientRow(rows *sql.Rows) (*model.SSOClient, error) { var c model.SSOClient var enabledInt int var tagsJSON, createdAtStr, updatedAtStr string var createdBy *int64 err := rows.Scan(&c.ID, &c.ClientID, &c.RedirectURI, &tagsJSON, &enabledInt, &createdBy, &createdAtStr, &updatedAtStr) if err != nil { return nil, fmt.Errorf("db: scan SSO client row: %w", err) } return finishSSOClientScan(&c, enabledInt, createdBy, tagsJSON, createdAtStr, updatedAtStr) } func finishSSOClientScan(c *model.SSOClient, enabledInt int, createdBy *int64, tagsJSON, createdAtStr, updatedAtStr string) (*model.SSOClient, error) { c.Enabled = enabledInt == 1 c.CreatedBy = createdBy var err error if c.CreatedAt, err = parseTime(createdAtStr); err != nil { return nil, err } if c.UpdatedAt, err = parseTime(updatedAtStr); err != nil { return nil, err } if err := json.Unmarshal([]byte(tagsJSON), &c.Tags); err != nil { return nil, fmt.Errorf("db: unmarshal SSO client tags: %w", err) } if c.Tags == nil { c.Tags = []string{} } return c, nil }