// Package db provides SQLite database management, migration support, // and transaction helpers for the exo system. package db import ( "context" "database/sql" "embed" "fmt" "log" "sort" "strings" "time" _ "github.com/mattn/go-sqlite3" ) //go:embed migrations/*.sql var migrationsFS embed.FS const iso8601 = "2006-01-02 15:04:05" // ToDBTime formats a time.Time as an ISO 8601 UTC string for storage. func ToDBTime(t time.Time) string { return t.UTC().Format(iso8601) } // FromDBTime parses an ISO 8601 UTC string back to a time.Time. // If loc is non-nil, the result is converted to that location. func FromDBTime(datetime string, loc *time.Location) (time.Time, error) { t, err := time.Parse(iso8601, datetime) if err != nil { return t, fmt.Errorf("db: failed to parse time %q: %w", datetime, err) } if loc != nil { t = t.In(loc) } return t, nil } // Open opens a SQLite database at the given path with standard pragmas. func Open(path string) (*sql.DB, error) { db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000") if err != nil { return nil, fmt.Errorf("db: failed to open database %q: %w", path, err) } // Verify the connection works. if err := db.Ping(); err != nil { _ = db.Close() return nil, fmt.Errorf("db: failed to ping database %q: %w", path, err) } return db, nil } // StartTX begins a new database transaction. func StartTX(ctx context.Context, db *sql.DB) (*sql.Tx, error) { return db.BeginTx(ctx, nil) } // EndTX commits or rolls back a transaction based on the error value. // If err is non-nil, the transaction is rolled back. Otherwise it is committed. func EndTX(tx *sql.Tx, err error) error { if err != nil { rbErr := tx.Rollback() if rbErr != nil { return fmt.Errorf("db: rollback failed (%w) after error: %w", rbErr, err) } return err } return tx.Commit() } // Migrate runs all pending migrations against the database. 
// Migrations are embedded SQL files in the migrations/ directory,
// named with a numeric prefix (e.g., 001_initial.sql).
//
// Files are applied in ascending numeric order of that prefix, not lexical
// filename order, so 10_x.sql runs after 9_x.sql. Only files whose version
// exceeds the highest version recorded in schema_version are applied. Each
// migration runs in its own transaction, together with the insertion of its
// schema_version row. A migration that succeeds is therefore never re-applied,
// and one that fails leaves no trace. Processing stops at the first error.
func Migrate(database *sql.DB) error {
	// Ensure schema_version table exists for tracking.
	if _, err := database.Exec(`CREATE TABLE IF NOT EXISTS schema_version (
	version INTEGER NOT NULL,
	applied TEXT NOT NULL
)`); err != nil {
		return fmt.Errorf("db: failed to ensure schema_version table: %w", err)
	}

	currentVersion, err := getCurrentVersion(database)
	if err != nil {
		return fmt.Errorf("db: failed to get current schema version: %w", err)
	}

	migrations, err := listMigrations()
	if err != nil {
		return err
	}

	for _, m := range migrations {
		if m.version <= currentVersion {
			continue
		}
		if err := applyMigration(database, m); err != nil {
			return err
		}
	}
	return nil
}

// migrationFile identifies one embedded migration script.
type migrationFile struct {
	version int    // numeric prefix parsed from the file name
	name    string // base file name inside migrations/
}

// listMigrations returns every embedded *.sql migration, ordered by numeric
// version with the file name as tie-breaker.
func listMigrations() ([]migrationFile, error) {
	entries, err := migrationsFS.ReadDir("migrations")
	if err != nil {
		return nil, fmt.Errorf("db: failed to read migrations directory: %w", err)
	}

	migrations := make([]migrationFile, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
			continue
		}
		var version int
		if _, err := fmt.Sscanf(entry.Name(), "%d_", &version); err != nil {
			return nil, fmt.Errorf("db: failed to parse migration version from %q: %w", entry.Name(), err)
		}
		migrations = append(migrations, migrationFile{version: version, name: entry.Name()})
	}

	// Sort numerically: lexical order would put "10_..." before "9_..." when
	// prefixes are not zero-padded.
	sort.Slice(migrations, func(i, j int) bool {
		if migrations[i].version != migrations[j].version {
			return migrations[i].version < migrations[j].version
		}
		return migrations[i].name < migrations[j].name
	})
	return migrations, nil
}

// applyMigration executes one migration script and records its version in
// schema_version, both inside a single transaction.
func applyMigration(database *sql.DB, m migrationFile) error {
	sqlBytes, err := migrationsFS.ReadFile("migrations/" + m.name)
	if err != nil {
		return fmt.Errorf("db: failed to read migration %q: %w", m.name, err)
	}

	log.Printf("db: applying migration %d (%s)", m.version, m.name)

	tx, err := database.Begin()
	if err != nil {
		return fmt.Errorf("db: failed to begin migration transaction: %w", err)
	}
	if _, err := tx.Exec(string(sqlBytes)); err != nil {
		_ = tx.Rollback() // the migration error is the one worth reporting
		return fmt.Errorf("db: migration %d failed: %w", m.version, err)
	}
	// Record the version so later runs skip this migration. If a script also
	// inserts its own schema_version row, the duplicate is harmless, because
	// getCurrentVersion reads only MAX(version).
	if _, err := tx.Exec(
		`INSERT INTO schema_version (version, applied) VALUES (?, ?)`,
		m.version, ToDBTime(time.Now()),
	); err != nil {
		_ = tx.Rollback()
		return fmt.Errorf("db: failed to record migration %d: %w", m.version, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("db: failed to commit migration %d: %w", m.version, err)
	}

	log.Printf("db: migration %d applied successfully", m.version)
	return nil
}

// getCurrentVersion returns the highest version recorded in schema_version,
// or 0 when the table has no rows.
//
// The aggregate query always yields exactly one row, so an empty table is not
// an error. Any Scan error therefore signals a genuine database failure and is
// returned, rather than being mistaken for "no migrations applied". That
// mistake would cause every migration to be re-run.
func getCurrentVersion(database *sql.DB) (int, error) {
	var version int
	row := database.QueryRow(`SELECT COALESCE(MAX(version), 0) FROM schema_version`)
	if err := row.Scan(&version); err != nil {
		return 0, err
	}
	return version, nil
}

// DBObject is the interface for types that can be stored in and retrieved from
// the database within a caller-supplied transaction.
type DBObject interface {
	// Get retrieves the object from the database using tx.
	Get(ctx context.Context, tx *sql.Tx) error
	// Store writes the object to the database using tx.
	Store(ctx context.Context, tx *sql.Tx) error
}