package main import ( "database/sql" "errors" "fmt" "log" "os" "path/filepath" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" _ "github.com/golang-migrate/migrate/v4/source/file" _ "github.com/mattn/go-sqlite3" "github.com/spf13/cobra" "github.com/spf13/viper" ) var ( migrationsDir string steps int ) var migrateCmd = &cobra.Command{ Use: "migrate", Short: "Manage database migrations", Long: `Commands for managing database migrations in the MCIAS system.`, } var migrateUpCmd = &cobra.Command{ Use: "up [steps]", Short: "Apply migrations", Long: `Apply all or a specific number of migrations. If steps is not provided, all pending migrations will be applied.`, Run: func(cmd *cobra.Command, args []string) { runMigration("up", steps) }, } var migrateDownCmd = &cobra.Command{ Use: "down [steps]", Short: "Revert migrations", Long: `Revert all or a specific number of migrations. If steps is not provided, all applied migrations will be reverted.`, Run: func(cmd *cobra.Command, args []string) { runMigration("down", steps) }, } var migrateVersionCmd = &cobra.Command{ Use: "version", Short: "Show current migration version", Long: `Display the current migration version of the database.`, Run: func(cmd *cobra.Command, args []string) { showMigrationVersion() }, } func init() { rootCmd.AddCommand(migrateCmd) migrateCmd.AddCommand(migrateUpCmd) migrateCmd.AddCommand(migrateDownCmd) migrateCmd.AddCommand(migrateVersionCmd) migrateCmd.PersistentFlags().StringVarP(&migrationsDir, "migrations", "m", "database/migrations", "Directory containing migration files") migrateCmd.PersistentFlags().IntVarP(&steps, "steps", "s", 0, "Number of migrations to apply or revert (0 means all)") } func runMigration(direction string, steps int) { dbPath := viper.GetString("db") logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags) absPath, err := filepath.Abs(migrationsDir) if err != nil { logger.Fatalf("Failed to get absolute path for migrations directory: %v", err) } if _, err := os.Stat(absPath); os.IsNotExist(err) { logger.Fatalf("Migrations directory does not exist: %s", absPath) } db, err := openDatabase(dbPath) if err != nil { logger.Fatalf("Failed to open database: %v", err) } defer db.Close() driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { logger.Fatalf("Failed to create migration driver: %v", err) } m, err := migrate.NewWithDatabaseInstance( fmt.Sprintf("file://%s", absPath), "sqlite3", driver) if err != nil { logger.Fatalf("Failed to create migration instance: %v", err) } if direction == "up" { if steps > 0 { err = m.Steps(steps) } else { err = m.Up() } if err != nil && !errors.Is(err, migrate.ErrNoChange) { logger.Fatalf("Failed to apply migrations: %v", err) } logger.Println("Migrations applied successfully") } else if direction == "down" { if steps > 0 { err = m.Steps(-steps) } else { err = m.Down() } if err != nil && !errors.Is(err, migrate.ErrNoChange) { logger.Fatalf("Failed to revert migrations: %v", err) } logger.Println("Migrations reverted successfully") } } func showMigrationVersion() { dbPath := viper.GetString("db") logger := log.New(os.Stdout, "MCIAS Migration: ", log.LstdFlags) db, err := openDatabase(dbPath) if err != nil { logger.Fatalf("Failed to open database: %v", err) } defer db.Close() driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { logger.Fatalf("Failed to create migration driver: %v", err) } absPath, err := filepath.Abs(migrationsDir) if err != nil { logger.Fatalf("Failed to get absolute path for migrations directory: %v", err) } m, err := migrate.NewWithDatabaseInstance( fmt.Sprintf("file://%s", absPath), "sqlite3", driver) if err != nil { logger.Fatalf("Failed to create migration instance: %v", err) } version, dirty, err := m.Version() if err != nil { if errors.Is(err, migrate.ErrNilVersion) { logger.Println("No migrations have been applied yet") return } logger.Fatalf("Failed to get migration version: %v", err) } logger.Printf("Current migration version: %d (dirty: %t)", version, dirty) } func openDatabase(dbPath string) (*sql.DB, error) { dbDir := filepath.Dir(dbPath) if err := os.MkdirAll(dbDir, 0755); err != nil { return nil, fmt.Errorf("failed to create database directory: %w", err) } db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } return db, nil }