package db

import (
	"database/sql"
	"fmt"
)

// migration is a single versioned schema change.
type migration struct {
	version int                    // recorded in schema_migrations once applied; must be strictly increasing
	name    string                 // label used only in error messages
	fn      func(tx *sql.Tx) error // performs the change inside the supplied transaction
}

// migrations lists every schema change in the order Migrate applies them.
// Migrate skips every entry whose version is <= the highest recorded version,
// so new changes must be appended with a larger version number; editing or
// reordering entries that may already have been applied is not detected.
var migrations = []migration{
	{1, "create_core_tables", migrate001CreateCoreTables},
	{2, "add_proxy_protocol_and_l7_fields", migrate002AddL7Fields},
}

// Migrate brings the database schema up to date.
//
// It creates the schema_migrations bookkeeping table if needed, reads the
// highest recorded version, and then applies, in slice order, every entry of
// migrations with a greater version. Each migration and the row recording it
// are committed in one transaction, so a failing migration is rolled back
// without being recorded. Migrations that committed before the failure stay
// applied and are skipped on the next call.
func (s *Store) Migrate() error {
	// Skipping by MAX(version) is only sound if versions strictly increase in
	// slice order; reject a duplicate or out-of-order entry before touching
	// the database.
	for i := 1; i < len(migrations); i++ {
		if migrations[i].version <= migrations[i-1].version {
			return fmt.Errorf("migration %d (%s) is not ordered after migration %d (%s)",
				migrations[i].version, migrations[i].name,
				migrations[i-1].version, migrations[i-1].name)
		}
	}

	if _, err := s.db.Exec(`
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			applied TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
		)
	`); err != nil {
		return fmt.Errorf("creating schema_migrations table: %w", err)
	}

	// COALESCE turns the NULL that MAX returns on an empty table into 0, so a
	// fresh database applies every migration.
	var current int
	if err := s.db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&current); err != nil {
		return fmt.Errorf("querying current migration version: %w", err)
	}

	for _, m := range migrations {
		if m.version <= current {
			continue
		}
		if err := s.applyMigration(m); err != nil {
			return err
		}
	}
	return nil
}

// applyMigration runs m.fn and records m.version in schema_migrations within a
// single transaction, committing only if both succeed.
func (s *Store) applyMigration(m migration) error {
	tx, err := s.db.Begin()
	if err != nil {
		return fmt.Errorf("beginning migration %d (%s): %w", m.version, m.name, err)
	}
	// Rollback after a successful Commit does nothing (it just returns
	// sql.ErrTxDone), so an unconditional deferred Rollback undoes the
	// transaction on every error path and on a panic in m.fn. Its error is
	// deliberately ignored: the error returned below is the one that matters.
	defer tx.Rollback()

	if err := m.fn(tx); err != nil {
		return fmt.Errorf("running migration %d (%s): %w", m.version, m.name, err)
	}
	if _, err := tx.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.version); err != nil {
		return fmt.Errorf("recording migration %d (%s): %w", m.version, m.name, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing migration %d (%s): %w", m.version, m.name, err)
	}
	return nil
}

// execStatements executes stmts on tx in order, stopping at the first failure.
// The returned error names the 1-based position of the failing statement.
func execStatements(tx *sql.Tx, stmts ...string) error {
	for i, stmt := range stmts {
		if _, err := tx.Exec(stmt); err != nil {
			return fmt.Errorf("executing statement %d of %d: %w", i+1, len(stmts), err)
		}
	}
	return nil
}

// migrate001CreateCoreTables creates the listeners, routes and firewall_rules
// tables plus an index on routes.listener_id. Every statement uses IF NOT
// EXISTS, so objects that already exist are left untouched.
func migrate001CreateCoreTables(tx *sql.Tx) error {
	return execStatements(tx,
		`CREATE TABLE IF NOT EXISTS listeners (
			id INTEGER PRIMARY KEY,
			addr TEXT NOT NULL UNIQUE
		)`,
		// NOTE(review): SQLite enforces REFERENCES / ON DELETE CASCADE only
		// when PRAGMA foreign_keys=ON is set on the connection — confirm the
		// Store enables it.
		`CREATE TABLE IF NOT EXISTS routes (
			id INTEGER PRIMARY KEY,
			listener_id INTEGER NOT NULL REFERENCES listeners(id) ON DELETE CASCADE,
			hostname TEXT NOT NULL,
			backend TEXT NOT NULL,
			UNIQUE(listener_id, hostname)
		)`,
		// NOTE(review): likely redundant with the automatic index backing
		// UNIQUE(listener_id, hostname), whose leading column is listener_id.
		// Kept so that databases created by this migration stay identical.
		`CREATE INDEX IF NOT EXISTS idx_routes_listener ON routes(listener_id)`,
		`CREATE TABLE IF NOT EXISTS firewall_rules (
			id INTEGER PRIMARY KEY,
			type TEXT NOT NULL CHECK(type IN ('ip', 'cidr', 'country')),
			value TEXT NOT NULL,
			UNIQUE(type, value)
		)`,
	)
}

// migrate002AddL7Fields adds listeners.proxy_protocol and the per-route mode,
// TLS and proxy-protocol columns.
//
// ALTER TABLE ... ADD COLUMN fails if the column already exists, so this
// migration is not re-runnable by itself; Migrate's version tracking is what
// keeps it from running twice. Each column is NOT NULL with a non-NULL
// DEFAULT, which SQLite requires when adding a NOT NULL column.
//
// NOTE(review): the INTEGER ... DEFAULT 0 columns look like boolean flags and
// mode presumably distinguishes 'l4' from an L7 value — confirm against the
// code that reads these columns.
func migrate002AddL7Fields(tx *sql.Tx) error {
	return execStatements(tx,
		`ALTER TABLE listeners ADD COLUMN proxy_protocol INTEGER NOT NULL DEFAULT 0`,
		`ALTER TABLE routes ADD COLUMN mode TEXT NOT NULL DEFAULT 'l4'`,
		`ALTER TABLE routes ADD COLUMN tls_cert TEXT NOT NULL DEFAULT ''`,
		`ALTER TABLE routes ADD COLUMN tls_key TEXT NOT NULL DEFAULT ''`,
		`ALTER TABLE routes ADD COLUMN backend_tls INTEGER NOT NULL DEFAULT 0`,
		`ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0`,
	)
}