// mcias-junie/cmd/mcias/migrate.go (Go, 175 lines, 4.5 KiB)

package main
import (
"database/sql"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
_ "github.com/golang-migrate/migrate/v4/source/file"
_ "github.com/mattn/go-sqlite3"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
// Values of the persistent flags shared by every migrate subcommand. cobra
// fills them in from the flags registered on migrateCmd in init.
var (
	// migrationsDir is the directory holding the migration files
	// (--migrations / -m).
	migrationsDir string

	// steps is how many migrations to apply or revert (--steps / -s);
	// 0 means all of them.
	steps int
)
// migrateCmd is the parent "migrate" command. It has no Run function of its
// own; it only groups the up, down and version subcommands attached in init.
var migrateCmd = &cobra.Command{
	Use:   "migrate",
	Long:  `Commands for managing database migrations in the MCIAS system.`,
	Short: "Manage database migrations",
}
// migrateUpCmd implements "migrate up [steps]".
//
// The number of migrations to apply may be given either as the optional
// positional argument or with --steps/-s, but not both. A count of 0 (or no
// count at all) applies every pending migration; negative counts are rejected.
var migrateUpCmd = &cobra.Command{
	Use:   "up [steps]",
	Short: "Apply migrations",
	Long: `Apply all or a specific number of migrations.
If steps is not provided, all pending migrations will be applied.`,
	Args: cobra.MaximumNArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		if len(args) == 1 {
			// Refuse ambiguous input instead of silently picking one value.
			if cmd.Flags().Changed("steps") {
				log.Fatal("specify the number of steps either as an argument or with --steps, not both")
			}
			// Route the argument through the --steps flag so it is parsed and
			// validated exactly like the flag value and ends up in steps.
			if err := cmd.Flags().Set("steps", args[0]); err != nil {
				log.Fatalf("invalid steps argument: %v", err)
			}
		}
		if steps < 0 {
			// Without this check a negative count falls through to "apply all".
			log.Fatalf("invalid steps value %d: must be 0 (all) or a positive number", steps)
		}
		runMigration("up", steps)
	},
}
// migrateDownCmd implements "migrate down [steps]".
//
// The number of migrations to revert may be given either as the optional
// positional argument or with --steps/-s, but not both. A count of 0 (or no
// count at all) reverts every applied migration; negative counts are rejected.
var migrateDownCmd = &cobra.Command{
	Use:   "down [steps]",
	Short: "Revert migrations",
	Long: `Revert all or a specific number of migrations.
If steps is not provided, all applied migrations will be reverted.`,
	Args: cobra.MaximumNArgs(1),
	Run: func(cmd *cobra.Command, args []string) {
		if len(args) == 1 {
			// Refuse ambiguous input instead of silently picking one value.
			if cmd.Flags().Changed("steps") {
				log.Fatal("specify the number of steps either as an argument or with --steps, not both")
			}
			// Route the argument through the --steps flag so it is parsed and
			// validated exactly like the flag value and ends up in steps.
			// Ignoring it would turn "down 1" into "revert everything".
			if err := cmd.Flags().Set("steps", args[0]); err != nil {
				log.Fatalf("invalid steps argument: %v", err)
			}
		}
		if steps < 0 {
			// Without this check a negative count falls through to reverting
			// every applied migration.
			log.Fatalf("invalid steps value %d: must be 0 (all) or a positive number", steps)
		}
		runMigration("down", steps)
	},
}
// migrateVersionCmd implements "migrate version". It delegates entirely to
// showMigrationVersion and ignores its command and argument values.
var migrateVersionCmd = &cobra.Command{
	Use:   "version",
	Long:  `Display the current migration version of the database.`,
	Short: "Show current migration version",
	Run: func(_ *cobra.Command, _ []string) {
		showMigrationVersion()
	},
}
func init() {
rootCmd.AddCommand(migrateCmd)
migrateCmd.AddCommand(migrateUpCmd)
migrateCmd.AddCommand(migrateDownCmd)
migrateCmd.AddCommand(migrateVersionCmd)
migrateCmd.PersistentFlags().StringVarP(&migrationsDir, "migrations", "m", "database/migrations", "Directory containing migration files")
migrateCmd.PersistentFlags().IntVarP(&steps, "steps", "s", 0, "Number of migrations to apply or revert (0 means all)")
}
func runMigration(direction string, steps int) {
dbPath := viper.GetString("db")
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
absPath, err := filepath.Abs(migrationsDir)
if err != nil {
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
}
if _, err := os.Stat(absPath); os.IsNotExist(err) {
logger.Fatalf("Migrations directory does not exist: %s", absPath)
}
db, err := openDatabase(dbPath)
if err != nil {
logger.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
logger.Fatalf("Failed to create migration driver: %v", err)
}
m, err := migrate.NewWithDatabaseInstance(
fmt.Sprintf("file://%s", absPath),
"sqlite3", driver)
if err != nil {
logger.Fatalf("Failed to create migration instance: %v", err)
}
if direction == "up" {
if steps > 0 {
err = m.Steps(steps)
} else {
err = m.Up()
}
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
logger.Fatalf("Failed to apply migrations: %v", err)
}
logger.Println("Migrations applied successfully")
} else if direction == "down" {
if steps > 0 {
err = m.Steps(-steps)
} else {
err = m.Down()
}
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
logger.Fatalf("Failed to revert migrations: %v", err)
}
logger.Println("Migrations reverted successfully")
}
}
func showMigrationVersion() {
dbPath := viper.GetString("db")
logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags)
db, err := openDatabase(dbPath)
if err != nil {
logger.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
if err != nil {
logger.Fatalf("Failed to create migration driver: %v", err)
}
absPath, err := filepath.Abs(migrationsDir)
if err != nil {
logger.Fatalf("Failed to get absolute path for migrations directory: %v", err)
}
m, err := migrate.NewWithDatabaseInstance(
fmt.Sprintf("file://%s", absPath),
"sqlite3", driver)
if err != nil {
logger.Fatalf("Failed to create migration instance: %v", err)
}
version, dirty, err := m.Version()
if err != nil {
if errors.Is(err, migrate.ErrNilVersion) {
logger.Println("No migrations have been applied yet")
return
}
logger.Fatalf("Failed to get migration version: %v", err)
}
logger.Printf("Current migration version: %d (dirty: %t)", version, dirty)
}
// openDatabase opens the SQLite database at dbPath through the "sqlite3"
// database/sql driver. It first creates the parent directory (mode 0755) if
// that directory is missing.
//
// The connection is verified with Ping before returning, so a non-nil *sql.DB
// is usable. On error, nil is returned and no handle is left open.
//
// An empty dbPath is rejected. SQLite treats an empty filename as a private
// temporary database that is deleted when closed, so migrations would
// otherwise "succeed" against a throwaway database.
func openDatabase(dbPath string) (*sql.DB, error) {
	if dbPath == "" {
		return nil, errors.New("database path is empty")
	}
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return nil, fmt.Errorf("failed to create database directory: %w", err)
	}
	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	// sql.Open only validates its arguments. Ping forces a real connection so
	// problems such as an unopenable database file surface here.
	if err := db.Ping(); err != nil {
		_ = db.Close() // best effort; the Ping error is the one worth reporting
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	return db, nil
}