package db import ( "database/sql" "embed" "errors" "fmt" "strings" "github.com/golang-migrate/migrate/v4" sqlitedriver "github.com/golang-migrate/migrate/v4/database/sqlite" "github.com/golang-migrate/migrate/v4/source/iofs" _ "modernc.org/sqlite" // driver registration ) // migrationsFS embeds all migration SQL files from the migrations/ directory. // Each file is named NNN_description.up.sql (and optionally .down.sql). // //go:embed migrations/*.sql var migrationsFS embed.FS // LatestSchemaVersion is the highest migration version defined in the // migrations/ directory. Update this constant whenever a new migration file // is added. const LatestSchemaVersion = 6 // newMigrate constructs a migrate.Migrate instance backed by the embedded SQL // files. It opens a dedicated *sql.DB using the same DSN as the main // database so that calling m.Close() (which closes the underlying connection) // does not affect the caller's main database connection. // // Security: migration SQL is embedded at compile time from the migrations/ // directory and is never loaded from the filesystem at runtime, preventing // injection of arbitrary SQL via a compromised working directory. func newMigrate(database *DB) (*migrate.Migrate, error) { src, err := iofs.New(migrationsFS, "migrations") if err != nil { return nil, fmt.Errorf("db: create migration source: %w", err) } // Open a dedicated connection for the migrator. golang-migrate's sqlite // driver calls db.Close() when the migrator is closed; using a dedicated // connection (same DSN, different *sql.DB) prevents it from closing the // shared connection. For in-memory databases, Open() translates // ":memory:" to a named shared-cache URI so both connections see the same // data. migrateDB, err := sql.Open("sqlite", database.path) if err != nil { return nil, fmt.Errorf("db: open migration connection: %w", err) } migrateDB.SetMaxOpenConns(1) if _, err := migrateDB.Exec("PRAGMA foreign_keys=ON"); err != nil { _ = migrateDB.Close() return nil, fmt.Errorf("db: migration connection pragma: %w", err) } driver, err := sqlitedriver.WithInstance(migrateDB, &sqlitedriver.Config{ MigrationsTable: "schema_migrations", }) if err != nil { _ = migrateDB.Close() return nil, fmt.Errorf("db: create migration driver: %w", err) } m, err := migrate.NewWithInstance("iofs", src, "sqlite", driver) if err != nil { return nil, fmt.Errorf("db: initialise migrator: %w", err) } return m, nil } // Migrate applies any unapplied schema migrations to the database in order. // It is idempotent: running it on an already-current database is safe and // returns nil. // // Existing databases that were migrated by the previous hand-rolled runner // (schema_version table) are handled by the compatibility shim below: the // legacy version is read and used to fast-forward the golang-migrate state // before calling Up, so no migration is applied twice. func Migrate(database *DB) error { // Compatibility shim: if the database was previously migrated by the // hand-rolled runner it has a schema_version table with the current // version. Inform golang-migrate of the existing version so it does // not try to re-apply already-applied migrations. legacyVersion, err := legacySchemaVersion(database) if err != nil { return fmt.Errorf("db: read legacy schema version: %w", err) } m, err := newMigrate(database) if err != nil { return err } defer func() { src, drv := m.Close(); _ = src; _ = drv }() if legacyVersion > 0 { // Only fast-forward from the legacy version when golang-migrate has no // version record of its own yet (ErrNilVersion). If schema_migrations // already has an entry — including a dirty entry from a previously // failed migration — leave it alone and let golang-migrate handle it. // Overriding a non-nil version would discard progress (or a dirty // state that needs idempotent re-application) and cause migrations to // be retried unnecessarily. _, _, versionErr := m.Version() if errors.Is(versionErr, migrate.ErrNilVersion) { if err := m.Force(legacyVersion); err != nil { return fmt.Errorf("db: force legacy schema version %d: %w", legacyVersion, err) } } } if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) { // A "duplicate column name" error means the failing migration is an // ADD COLUMN that was already applied outside the migration runner // (common during development before a migration file existed). // If this is the last migration and its version matches LatestSchemaVersion, // force it clean so subsequent starts succeed. // // This is intentionally narrow: we only suppress the error when the // dirty version equals the latest known version, preventing accidental // masking of errors in intermediate migrations. if strings.Contains(err.Error(), "duplicate column name") { v, dirty, verErr := m.Version() if verErr == nil && dirty && int(v) == LatestSchemaVersion { //nolint:gosec // G115: safe conversion if forceErr := m.Force(LatestSchemaVersion); forceErr != nil { return fmt.Errorf("db: force after duplicate column: %w", forceErr) } return nil } } return fmt.Errorf("db: apply migrations: %w", err) } return nil } // ForceSchemaVersion marks the database as being at the given version without // running any SQL. This is a break-glass operation: use it to clear a dirty // migration state after verifying (or manually applying) the migration SQL. // // Passing a version that has never been recorded by golang-migrate is safe; // it simply sets the version and clears the dirty flag. The next call to // Migrate will apply any versions higher than the forced one. func ForceSchemaVersion(database *DB, version int) error { m, err := newMigrate(database) if err != nil { return err } defer func() { src, drv := m.Close(); _ = src; _ = drv }() if err := m.Force(version); err != nil { return fmt.Errorf("db: force schema version %d: %w", version, err) } return nil } // SchemaVersion returns the current applied schema version of the database. // Returns 0 if no migrations have been applied yet. func SchemaVersion(database *DB) (int, error) { m, err := newMigrate(database) if err != nil { return 0, err } defer func() { src, drv := m.Close(); _ = src; _ = drv }() v, _, err := m.Version() if errors.Is(err, migrate.ErrNilVersion) { return 0, nil } if err != nil { return 0, fmt.Errorf("db: read schema version: %w", err) } // Security: v is a migration version number (small positive integer); // the uint→int conversion is safe for any realistic schema version count. return int(v), nil //nolint:gosec // G115: migration version is always a small positive integer } // legacySchemaVersion reads the version from the old schema_version table // created by the hand-rolled migration runner. Returns 0 if the table does // not exist (fresh database or already migrated to golang-migrate only). func legacySchemaVersion(database *DB) (int, error) { var version int err := database.sql.QueryRow( `SELECT version FROM schema_version LIMIT 1`, ).Scan(&version) if err != nil { // Table does not exist or is empty — treat as version 0. return 0, nil //nolint:nilerr } return version, nil }