package db import ( "database/sql" "errors" "fmt" "time" "git.wntrmute.dev/kyle/mcias/internal/model" ) // ListCredentialedAccountIDs returns the set of account IDs that already have // a pg_credentials row. Used to filter the "uncredentialed system accounts" // list on the /pgcreds create form without leaking credential content. func (db *DB) ListCredentialedAccountIDs() (map[int64]struct{}, error) { rows, err := db.sql.Query(`SELECT account_id FROM pg_credentials`) if err != nil { return nil, fmt.Errorf("db: list credentialed account ids: %w", err) } defer func() { _ = rows.Close() }() ids := make(map[int64]struct{}) for rows.Next() { var id int64 if err := rows.Scan(&id); err != nil { return nil, fmt.Errorf("db: scan credentialed account id: %w", err) } ids[id] = struct{}{} } return ids, rows.Err() } // SetPGCredentialOwner records the owning account for a pg_credentials row. // This is called on first write so that pre-migration rows retain a nil owner. // It is idempotent: if the owner is already set it is overwritten. func (db *DB) SetPGCredentialOwner(credentialID, ownerID int64) error { _, err := db.sql.Exec(` UPDATE pg_credentials SET owner_id = ? WHERE id = ? `, ownerID, credentialID) if err != nil { return fmt.Errorf("db: set pg credential owner: %w", err) } return nil } // GetPGCredentialByID retrieves a single pg_credentials row by its primary key. // Returns ErrNotFound if no such credential exists. func (db *DB) GetPGCredentialByID(id int64) (*model.PGCredential, error) { var cred model.PGCredential var createdAtStr, updatedAtStr string var ownerID sql.NullInt64 err := db.sql.QueryRow(` SELECT p.id, p.account_id, p.pg_host, p.pg_port, p.pg_database, p.pg_username, p.pg_password_enc, p.pg_password_nonce, p.created_at, p.updated_at, p.owner_id FROM pg_credentials p WHERE p.id = ? `, id).Scan( &cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort, &cred.PGDatabase, &cred.PGUsername, &cred.PGPasswordEnc, &cred.PGPasswordNonce, &createdAtStr, &updatedAtStr, &ownerID, ) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("db: get pg credential by id: %w", err) } cred.CreatedAt, err = parseTime(createdAtStr) if err != nil { return nil, err } cred.UpdatedAt, err = parseTime(updatedAtStr) if err != nil { return nil, err } if ownerID.Valid { v := ownerID.Int64 cred.OwnerID = &v } return &cred, nil } // GrantPGCredAccess grants an account read access to a pg_credentials set. // If the grant already exists the call is a no-op (UNIQUE constraint). // grantedBy may be nil if the grant is made programmatically. func (db *DB) GrantPGCredAccess(credentialID, granteeID int64, grantedBy *int64) error { n := now() _, err := db.sql.Exec(` INSERT INTO pg_credential_access (credential_id, grantee_id, granted_by, granted_at) VALUES (?, ?, ?, ?) ON CONFLICT(credential_id, grantee_id) DO NOTHING `, credentialID, granteeID, grantedBy, n) if err != nil { return fmt.Errorf("db: grant pg cred access: %w", err) } return nil } // RevokePGCredAccess removes a grantee's access to a pg_credentials set. func (db *DB) RevokePGCredAccess(credentialID, granteeID int64) error { _, err := db.sql.Exec(` DELETE FROM pg_credential_access WHERE credential_id = ? AND grantee_id = ? `, credentialID, granteeID) if err != nil { return fmt.Errorf("db: revoke pg cred access: %w", err) } return nil } // ListPGCredAccess returns all access grants for a pg_credentials set, // joining against accounts to populate grantee username and UUID. func (db *DB) ListPGCredAccess(credentialID int64) ([]*model.PGCredAccessGrant, error) { rows, err := db.sql.Query(` SELECT pca.id, pca.credential_id, pca.grantee_id, pca.granted_by, pca.granted_at, a.uuid, a.username FROM pg_credential_access pca JOIN accounts a ON a.id = pca.grantee_id WHERE pca.credential_id = ? ORDER BY pca.granted_at ASC `, credentialID) if err != nil { return nil, fmt.Errorf("db: list pg cred access: %w", err) } defer func() { _ = rows.Close() }() var grants []*model.PGCredAccessGrant for rows.Next() { g, err := scanPGCredAccessGrant(rows) if err != nil { return nil, err } grants = append(grants, g) } return grants, rows.Err() } // CheckPGCredAccess reports whether accountID has an explicit access grant for // credentialID. The credential owner always has access implicitly; callers // must check ownership separately. func (db *DB) CheckPGCredAccess(credentialID, accountID int64) (bool, error) { var count int err := db.sql.QueryRow(` SELECT COUNT(*) FROM pg_credential_access WHERE credential_id = ? AND grantee_id = ? `, credentialID, accountID).Scan(&count) if err != nil { return false, fmt.Errorf("db: check pg cred access: %w", err) } return count > 0, nil } // PGCredWithAccount extends PGCredential with the owning system account's // username, used for the "My PG Credentials" listing view. type PGCredWithAccount struct { model.PGCredential } // ListAccessiblePGCreds returns all pg_credentials rows that accountID may // view: those where accountID is the owner, plus those where an explicit // access grant exists. The ServiceUsername and ServiceAccountUUID fields are // populated from the owning system account for display and navigation. func (db *DB) ListAccessiblePGCreds(accountID int64) ([]*model.PGCredential, error) { rows, err := db.sql.Query(` SELECT p.id, p.account_id, p.pg_host, p.pg_port, p.pg_database, p.pg_username, p.pg_password_enc, p.pg_password_nonce, p.created_at, p.updated_at, p.owner_id, a.username, a.uuid FROM pg_credentials p JOIN accounts a ON a.id = p.account_id WHERE p.owner_id = ? OR EXISTS ( SELECT 1 FROM pg_credential_access pca WHERE pca.credential_id = p.id AND pca.grantee_id = ? ) ORDER BY a.username ASC `, accountID, accountID) if err != nil { return nil, fmt.Errorf("db: list accessible pg creds: %w", err) } defer func() { _ = rows.Close() }() var creds []*model.PGCredential for rows.Next() { cred, err := scanPGCredWithUsername(rows) if err != nil { return nil, err } creds = append(creds, cred) } return creds, rows.Err() } func scanPGCredWithUsername(rows *sql.Rows) (*model.PGCredential, error) { var cred model.PGCredential var createdAtStr, updatedAtStr string var ownerID sql.NullInt64 err := rows.Scan( &cred.ID, &cred.AccountID, &cred.PGHost, &cred.PGPort, &cred.PGDatabase, &cred.PGUsername, &cred.PGPasswordEnc, &cred.PGPasswordNonce, &createdAtStr, &updatedAtStr, &ownerID, &cred.ServiceUsername, &cred.ServiceAccountUUID, ) if err != nil { return nil, fmt.Errorf("db: scan pg cred with username: %w", err) } cred.CreatedAt, err = parseTime(createdAtStr) if err != nil { return nil, err } cred.UpdatedAt, err = parseTime(updatedAtStr) if err != nil { return nil, err } if ownerID.Valid { v := ownerID.Int64 cred.OwnerID = &v } return &cred, nil } func scanPGCredAccessGrant(rows *sql.Rows) (*model.PGCredAccessGrant, error) { var g model.PGCredAccessGrant var grantedAtStr string var grantedBy sql.NullInt64 err := rows.Scan( &g.ID, &g.CredentialID, &g.GranteeID, &grantedBy, &grantedAtStr, &g.GranteeUUID, &g.GranteeName, ) if err != nil { return nil, fmt.Errorf("db: scan pg cred access grant: %w", err) } g.GrantedAt, err = time.Parse("2006-01-02T15:04:05Z", grantedAtStr) if err != nil { return nil, fmt.Errorf("db: parse pg cred access grant time %q: %w", grantedAtStr, err) } if grantedBy.Valid { v := grantedBy.Int64 g.GrantedBy = &v } return &g, nil }