package db import ( "database/sql" "errors" "fmt" "git.wntrmute.dev/kyle/mcias/internal/model" ) // CreateWebAuthnCredential inserts a new WebAuthn credential record. // All encrypted fields (credential_id, public_key) must be encrypted by the caller. func (db *DB) CreateWebAuthnCredential(cred *model.WebAuthnCredential) (int64, error) { n := now() result, err := db.sql.Exec(` INSERT INTO webauthn_credentials (account_id, name, credential_id_enc, credential_id_nonce, public_key_enc, public_key_nonce, aaguid, sign_count, discoverable, transports, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, cred.AccountID, cred.Name, cred.CredentialIDEnc, cred.CredentialIDNonce, cred.PublicKeyEnc, cred.PublicKeyNonce, cred.AAGUID, cred.SignCount, boolToInt(cred.Discoverable), cred.Transports, n, n) if err != nil { return 0, fmt.Errorf("db: create webauthn credential: %w", err) } id, err := result.LastInsertId() if err != nil { return 0, fmt.Errorf("db: webauthn credential last insert id: %w", err) } return id, nil } // GetWebAuthnCredentials returns all WebAuthn credentials for an account. func (db *DB) GetWebAuthnCredentials(accountID int64) ([]*model.WebAuthnCredential, error) { rows, err := db.sql.Query(` SELECT id, account_id, name, credential_id_enc, credential_id_nonce, public_key_enc, public_key_nonce, aaguid, sign_count, discoverable, transports, created_at, updated_at, last_used_at FROM webauthn_credentials WHERE account_id = ? ORDER BY created_at ASC`, accountID) if err != nil { return nil, fmt.Errorf("db: list webauthn credentials: %w", err) } defer rows.Close() //nolint:errcheck // rows.Close error is non-fatal return scanWebAuthnCredentials(rows) } // GetWebAuthnCredentialByID returns a single WebAuthn credential by its DB row ID. // Returns ErrNotFound if the credential does not exist. 
func (db *DB) GetWebAuthnCredentialByID(id int64) (*model.WebAuthnCredential, error) {
	// NOTE: this lookup filters only on id; it does not check which account
	// owns the credential. Callers that need ownership must compare AccountID.
	row := db.sql.QueryRow(` SELECT id, account_id, name, credential_id_enc, credential_id_nonce, public_key_enc, public_key_nonce, aaguid, sign_count, discoverable, transports, created_at, updated_at, last_used_at FROM webauthn_credentials WHERE id = ?`, id)
	// scanWebAuthnCredential translates sql.ErrNoRows into ErrNotFound.
	return scanWebAuthnCredential(row)
}

// DeleteWebAuthnCredential deletes a WebAuthn credential by ID, verifying ownership.
// The row is only removed if it belongs to accountID.
// Returns ErrNotFound if the credential does not exist or does not belong to the account.
func (db *DB) DeleteWebAuthnCredential(id, accountID int64) error {
	result, err := db.sql.Exec(
		`DELETE FROM webauthn_credentials WHERE id = ? AND account_id = ?`, id, accountID)
	if err != nil {
		return fmt.Errorf("db: delete webauthn credential: %w", err)
	}
	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("db: webauthn delete rows affected: %w", err)
	}
	if n == 0 {
		// Missing row and foreign-owned row are deliberately indistinguishable.
		return ErrNotFound
	}
	return nil
}

// DeleteWebAuthnCredentialAdmin deletes a WebAuthn credential by ID without ownership check.
// Returns ErrNotFound if no credential with that ID exists.
func (db *DB) DeleteWebAuthnCredentialAdmin(id int64) error {
	result, err := db.sql.Exec(`DELETE FROM webauthn_credentials WHERE id = ?`, id)
	if err != nil {
		return fmt.Errorf("db: admin delete webauthn credential: %w", err)
	}
	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("db: webauthn admin delete rows affected: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}

// DeleteAllWebAuthnCredentials removes all WebAuthn credentials for an account
// and returns how many rows were deleted. Deleting from an account that has no
// credentials is not an error; it returns 0.
func (db *DB) DeleteAllWebAuthnCredentials(accountID int64) (int64, error) {
	result, err := db.sql.Exec(
		`DELETE FROM webauthn_credentials WHERE account_id = ?`, accountID)
	if err != nil {
		return 0, fmt.Errorf("db: delete all webauthn credentials: %w", err)
	}
	// Wrap the RowsAffected error like every other query in this file instead
	// of passing it through bare.
	n, err := result.RowsAffected()
	if err != nil {
		return 0, fmt.Errorf("db: webauthn delete all rows affected: %w", err)
	}
	return n, nil
}

// UpdateWebAuthnSignCount updates the sign counter for a credential.
func (db *DB) UpdateWebAuthnSignCount(id int64, signCount uint32) error { _, err := db.sql.Exec( `UPDATE webauthn_credentials SET sign_count = ?, updated_at = ? WHERE id = ?`, signCount, now(), id) if err != nil { return fmt.Errorf("db: update webauthn sign count: %w", err) } return nil } // UpdateWebAuthnLastUsed sets the last_used_at timestamp for a credential. func (db *DB) UpdateWebAuthnLastUsed(id int64) error { _, err := db.sql.Exec( `UPDATE webauthn_credentials SET last_used_at = ?, updated_at = ? WHERE id = ?`, now(), now(), id) if err != nil { return fmt.Errorf("db: update webauthn last used: %w", err) } return nil } // HasWebAuthnCredentials reports whether the account has any WebAuthn credentials. func (db *DB) HasWebAuthnCredentials(accountID int64) (bool, error) { var count int err := db.sql.QueryRow( `SELECT COUNT(*) FROM webauthn_credentials WHERE account_id = ?`, accountID).Scan(&count) if err != nil { return false, fmt.Errorf("db: count webauthn credentials: %w", err) } return count > 0, nil } // CountWebAuthnCredentials returns the number of WebAuthn credentials for an account. func (db *DB) CountWebAuthnCredentials(accountID int64) (int, error) { var count int err := db.sql.QueryRow( `SELECT COUNT(*) FROM webauthn_credentials WHERE account_id = ?`, accountID).Scan(&count) if err != nil { return 0, fmt.Errorf("db: count webauthn credentials: %w", err) } return count, nil } // boolToInt converts a bool to 0/1 for SQLite storage. func boolToInt(b bool) int { if b { return 1 } return 0 } func scanWebAuthnCredentials(rows *sql.Rows) ([]*model.WebAuthnCredential, error) { var creds []*model.WebAuthnCredential for rows.Next() { cred, err := scanWebAuthnRow(rows) if err != nil { return nil, err } creds = append(creds, cred) } return creds, rows.Err() } // scannable is implemented by both *sql.Row and *sql.Rows. 
type scannable interface { Scan(dest ...any) error } func scanWebAuthnRow(s scannable) (*model.WebAuthnCredential, error) { var cred model.WebAuthnCredential var createdAt, updatedAt string var lastUsedAt *string var discoverable int err := s.Scan( &cred.ID, &cred.AccountID, &cred.Name, &cred.CredentialIDEnc, &cred.CredentialIDNonce, &cred.PublicKeyEnc, &cred.PublicKeyNonce, &cred.AAGUID, &cred.SignCount, &discoverable, &cred.Transports, &createdAt, &updatedAt, &lastUsedAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } return nil, fmt.Errorf("db: scan webauthn credential: %w", err) } cred.Discoverable = discoverable != 0 cred.CreatedAt, err = parseTime(createdAt) if err != nil { return nil, err } cred.UpdatedAt, err = parseTime(updatedAt) if err != nil { return nil, err } cred.LastUsedAt, err = nullableTime(lastUsedAt) if err != nil { return nil, err } return &cred, nil } func scanWebAuthnCredential(row *sql.Row) (*model.WebAuthnCredential, error) { return scanWebAuthnRow(row) }