package db

import (
	"database/sql"
	"embed"
	"errors"
	"fmt"

	"github.com/golang-migrate/migrate/v4"
	sqlitedriver "github.com/golang-migrate/migrate/v4/database/sqlite"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	_ "modernc.org/sqlite" // driver registration
)

// migrationsFS embeds all migration SQL files from the migrations/ directory.
// Each file is named NNN_description.up.sql (and optionally .down.sql).
//
//go:embed migrations/*.sql
var migrationsFS embed.FS

// LatestSchemaVersion is the highest migration version defined in the
// migrations/ directory. Update this constant whenever a new migration file
// is added.
const LatestSchemaVersion = 5

// newMigrate constructs a migrate.Migrate instance backed by the embedded SQL
// files. It opens a dedicated *sql.DB using the same DSN as the main
// database so that calling m.Close() (which closes the underlying connection)
// does not affect the caller's main database connection.
//
// On any error, every resource opened so far (the migration source and the
// dedicated connection) is released before returning.
//
// Security: migration SQL is embedded at compile time from the migrations/
// directory and is never loaded from the filesystem at runtime, preventing
// injection of arbitrary SQL via a compromised working directory.
func newMigrate(database *DB) (*migrate.Migrate, error) {
	src, err := iofs.New(migrationsFS, "migrations")
	if err != nil {
		return nil, fmt.Errorf("db: create migration source: %w", err)
	}

	// Open a dedicated connection for the migrator. golang-migrate's sqlite
	// driver calls db.Close() when the migrator is closed; using a dedicated
	// connection (same DSN, different *sql.DB) prevents it from closing the
	// shared connection. For in-memory databases, Open() translates
	// ":memory:" to a named shared-cache URI so both connections see the same
	// data.
	migrateDB, err := sql.Open("sqlite", database.path)
	if err != nil {
		_ = src.Close()
		return nil, fmt.Errorf("db: open migration connection: %w", err)
	}
	// A single pooled connection means the per-connection PRAGMA below is in
	// effect for the statements the migrator issues through this pool.
	migrateDB.SetMaxOpenConns(1)
	if _, err := migrateDB.Exec("PRAGMA foreign_keys=ON"); err != nil {
		_ = migrateDB.Close()
		_ = src.Close()
		return nil, fmt.Errorf("db: migration connection pragma: %w", err)
	}

	driver, err := sqlitedriver.WithInstance(migrateDB, &sqlitedriver.Config{
		MigrationsTable: "schema_migrations",
	})
	if err != nil {
		_ = migrateDB.Close()
		_ = src.Close()
		return nil, fmt.Errorf("db: create migration driver: %w", err)
	}

	m, err := migrate.NewWithInstance("iofs", src, "sqlite", driver)
	if err != nil {
		// The migrator was not created, so nobody else will close the
		// dedicated connection or the source; release them here.
		_ = migrateDB.Close()
		_ = src.Close()
		return nil, fmt.Errorf("db: initialise migrator: %w", err)
	}
	return m, nil
}

// closeMigrate releases the migrator's source and its dedicated database
// connection. Close errors are intentionally discarded: the migrator's work
// has already either succeeded or produced the error worth reporting, and
// only the dedicated migration connection is being closed, not the caller's
// main connection.
func closeMigrate(m *migrate.Migrate) {
	_, _ = m.Close()
}

// Migrate applies any unapplied schema migrations to the database in order.
// It is idempotent: running it on an already-current database is safe and
// returns nil.
//
// Databases previously migrated by the hand-rolled runner (schema_version
// table) are adopted exactly once: if golang-migrate has not recorded any
// version yet, the legacy version is forced into golang-migrate's state
// before calling Up, so no migration is applied twice. Once golang-migrate
// has a recorded version it is authoritative and the legacy version is not
// used again.
func Migrate(database *DB) error {
	legacyVersion, err := legacySchemaVersion(database)
	if err != nil {
		return fmt.Errorf("db: read legacy schema version: %w", err)
	}

	m, err := newMigrate(database)
	if err != nil {
		return err
	}
	defer closeMigrate(m)

	if legacyVersion > 0 {
		if err := adoptLegacyVersion(m, legacyVersion); err != nil {
			return err
		}
	}

	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return fmt.Errorf("db: apply migrations: %w", err)
	}
	return nil
}

// adoptLegacyVersion forces the migrator to legacyVersion, but only when
// golang-migrate has no recorded version yet (first run after switching from
// the hand-rolled runner). If a version is already recorded, it is left
// untouched: forcing again would rewind the state so that Up re-applies
// newer migrations, and it would also clear a dirty flag that Up should
// report instead.
func adoptLegacyVersion(m *migrate.Migrate, legacyVersion int) error {
	_, _, err := m.Version()
	switch {
	case err == nil:
		// golang-migrate already tracks this database.
		return nil
	case !errors.Is(err, migrate.ErrNilVersion):
		return fmt.Errorf("db: read schema version: %w", err)
	}

	if err := m.Force(legacyVersion); err != nil {
		return fmt.Errorf("db: force legacy schema version %d: %w", legacyVersion, err)
	}
	return nil
}

// SchemaVersion returns the current applied schema version of the database
// as recorded by golang-migrate. Returns 0 if no migrations have been
// applied yet. The dirty flag reported by golang-migrate is not inspected.
func SchemaVersion(database *DB) (int, error) {
	m, err := newMigrate(database)
	if err != nil {
		return 0, err
	}
	defer closeMigrate(m)

	v, _, err := m.Version()
	if errors.Is(err, migrate.ErrNilVersion) {
		return 0, nil
	}
	if err != nil {
		return 0, fmt.Errorf("db: read schema version: %w", err)
	}
	// Security: v is a migration version number (small positive integer);
	// the uint→int conversion is safe for any realistic schema version count.
	return int(v), nil //nolint:gosec // G115: migration version is always a small positive integer
}

// legacySchemaVersion reads the version from the old schema_version table
// created by the hand-rolled migration runner.
//
// It returns 0 with a nil error when the table does not exist (fresh
// database) or holds no rows. Any other failure is returned to the caller
// instead of being mistaken for "no legacy version". Mistaking a failure
// that way could cause already-applied migrations to be re-run.
func legacySchemaVersion(database *DB) (int, error) {
	var tables int
	if err := database.sql.QueryRow(
		`SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = 'schema_version'`,
	).Scan(&tables); err != nil {
		return 0, fmt.Errorf("look up schema_version table: %w", err)
	}
	if tables == 0 {
		return 0, nil
	}

	// MAX is NULL for an empty table, hence NullInt64. It also gives a
	// deterministic answer if the legacy runner ever stored more than one row
	// (presumably it stored a single row; NOTE(review): confirm).
	var version sql.NullInt64
	if err := database.sql.QueryRow(
		`SELECT MAX(version) FROM schema_version`,
	).Scan(&version); err != nil {
		return 0, fmt.Errorf("query schema_version: %w", err)
	}
	if !version.Valid {
		return 0, nil
	}
	return int(version.Int64), nil //nolint:gosec // G115: migration version is always a small positive integer
}