Migrate db and config to mcdsl

- db.Open: delegate to mcdsl/db.Open (WAL, FK, busy timeout, 0600)
- db.Migrate: convert function-based migrations to mcdsl/db.Migration
  SQL strings, delegate to mcdsl/db.Migrate
- db.Snapshot: delegate to mcdsl/db.Snapshot (adds 0600 permissions)
- config: replace local Duration with mcdsl/config.Duration alias,
  replace Load with mcdsl/config.Load[T] + Validator interface
- Remove direct modernc.org/sqlite and go-toml/v2 dependencies
  (now indirect via mcdsl)
- Update TestEnvOverrideInvalidDuration: mcdsl silently ignores
  invalid env duration values (behavioral change from migration)
- All existing tests pass

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 16:57:02 -07:00
parent 564e0a9c67
commit 1ad42dbbee
7 changed files with 116 additions and 245 deletions

13
go.mod
View File

@@ -1,29 +1,32 @@
module git.wntrmute.dev/kyle/mc-proxy module git.wntrmute.dev/kyle/mc-proxy
go 1.25.0 go 1.25.7
require ( require (
git.wntrmute.dev/kyle/mcdsl v0.0.0
github.com/oschwald/maxminddb-golang v1.13.1 github.com/oschwald/maxminddb-golang v1.13.1
github.com/pelletier/go-toml/v2 v2.2.4
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
google.golang.org/grpc v1.79.2 golang.org/x/net v0.48.0
google.golang.org/grpc v1.79.3
google.golang.org/protobuf v1.36.11 google.golang.org/protobuf v1.36.11
modernc.org/sqlite v1.46.2
) )
replace git.wntrmute.dev/kyle/mcdsl => /home/kyle/src/metacircular/mcdsl
require ( require (
github.com/dustin/go-humanize v1.0.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/pelletier/go-toml/v2 v2.3.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/spf13/pflag v1.0.9 // indirect github.com/spf13/pflag v1.0.9 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.32.0 // indirect golang.org/x/text v0.32.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
modernc.org/libc v1.70.0 // indirect modernc.org/libc v1.70.0 // indirect
modernc.org/mathutil v1.7.1 // indirect modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect modernc.org/memory v1.11.0 // indirect
modernc.org/sqlite v1.47.0 // indirect
) )

12
go.sum
View File

@@ -27,8 +27,8 @@ github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOF
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf421CaM=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
@@ -70,8 +70,8 @@ gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE=
google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@@ -99,8 +99,8 @@ modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.2 h1:gkXQ6R0+AjxFC/fTDaeIVLbNLNrRoOK7YYVz5BKhTcE= modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk=
modernc.org/sqlite v1.46.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=

View File

@@ -6,11 +6,16 @@ import (
"log/slog" "log/slog"
"os" "os"
"strings" "strings"
"time"
"github.com/pelletier/go-toml/v2" mcdslconfig "git.wntrmute.dev/kyle/mcdsl/config"
) )
// Duration is an alias for the mcdsl config.Duration type, which wraps
// time.Duration with TOML string unmarshalling support. Exported so
// existing code that references config.Duration continues to work.
type Duration = mcdslconfig.Duration
// Config is the top-level mc-proxy configuration.
type Config struct { type Config struct {
Listeners []Listener `toml:"listeners"` Listeners []Listener `toml:"listeners"`
Database Database `toml:"database"` Database Database `toml:"database"`
@@ -20,14 +25,17 @@ type Config struct {
Log Log `toml:"log"` Log Log `toml:"log"`
} }
// Database holds the database configuration.
type Database struct { type Database struct {
Path string `toml:"path"` Path string `toml:"path"`
} }
// GRPC holds the gRPC admin API configuration.
type GRPC struct { type GRPC struct {
Addr string `toml:"addr"` // Unix socket path (e.g., "/var/run/mc-proxy.sock") Addr string `toml:"addr"` // Unix socket path (e.g., "/var/run/mc-proxy.sock")
} }
// Listener is a proxy listener with its routes.
type Listener struct { type Listener struct {
Addr string `toml:"addr"` Addr string `toml:"addr"`
ProxyProtocol bool `toml:"proxy_protocol"` ProxyProtocol bool `toml:"proxy_protocol"`
@@ -35,6 +43,7 @@ type Listener struct {
Routes []Route `toml:"routes"` Routes []Route `toml:"routes"`
} }
// Route is a proxy route within a listener.
type Route struct { type Route struct {
Hostname string `toml:"hostname"` Hostname string `toml:"hostname"`
Backend string `toml:"backend"` Backend string `toml:"backend"`
@@ -45,6 +54,7 @@ type Route struct {
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
} }
// Firewall holds the global firewall configuration.
type Firewall struct { type Firewall struct {
GeoIPDB string `toml:"geoip_db"` GeoIPDB string `toml:"geoip_db"`
BlockedIPs []string `toml:"blocked_ips"` BlockedIPs []string `toml:"blocked_ips"`
@@ -54,72 +64,49 @@ type Firewall struct {
RateWindow Duration `toml:"rate_window"` RateWindow Duration `toml:"rate_window"`
} }
// Proxy holds proxy behavior timeouts.
type Proxy struct { type Proxy struct {
ConnectTimeout Duration `toml:"connect_timeout"` ConnectTimeout Duration `toml:"connect_timeout"`
IdleTimeout Duration `toml:"idle_timeout"` IdleTimeout Duration `toml:"idle_timeout"`
ShutdownTimeout Duration `toml:"shutdown_timeout"` ShutdownTimeout Duration `toml:"shutdown_timeout"`
} }
// Log holds logging configuration.
type Log struct { type Log struct {
Level string `toml:"level"` Level string `toml:"level"`
} }
// Duration wraps time.Duration for TOML string unmarshalling.
type Duration struct {
time.Duration
}
// SocketPath returns the filesystem path for the Unix socket, // SocketPath returns the filesystem path for the Unix socket,
// stripping any "unix:" prefix. // stripping any "unix:" prefix.
func (g GRPC) SocketPath() string { func (g GRPC) SocketPath() string {
return strings.TrimPrefix(g.Addr, "unix:") return strings.TrimPrefix(g.Addr, "unix:")
} }
func (d *Duration) UnmarshalText(text []byte) error { // Load reads and validates the mc-proxy configuration from a TOML file.
var err error // Environment variables with the MCPROXY_ prefix override config values.
d.Duration, err = time.ParseDuration(string(text))
return err
}
func Load(path string) (*Config, error) { func Load(path string) (*Config, error) {
data, err := os.ReadFile(path) cfg, err := mcdslconfig.Load[Config](path, "MCPROXY")
if err != nil { if err != nil {
return nil, fmt.Errorf("reading config: %w", err) return nil, err
} }
return cfg, nil
var cfg Config
if err := toml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parsing config: %w", err)
}
if err := cfg.applyEnvOverrides(); err != nil {
return nil, fmt.Errorf("applying env overrides: %w", err)
}
if err := cfg.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
return &cfg, nil
} }
// applyEnvOverrides applies environment variable overrides to the config. // Validate implements the mcdsl config.Validator interface. It applies
// Variables use the MCPROXY_ prefix with underscore-separated paths. // manual env overrides for fields that the generic reflection-based
func (c *Config) applyEnvOverrides() error { // system cannot handle (int64, error-returning duration parsing), then
// Database // validates all config fields.
if v := os.Getenv("MCPROXY_DATABASE_PATH"); v != "" { func (c *Config) Validate() error {
c.Database.Path = v if err := c.applyManualEnvOverrides(); err != nil {
return err
} }
return c.validate()
}
// gRPC // applyManualEnvOverrides handles env overrides that need error reporting
if v := os.Getenv("MCPROXY_GRPC_ADDR"); v != "" { // or non-standard types (int64 rate limits, duration fields that
c.GRPC.Addr = v // reflection already handles but we want error semantics for).
} func (c *Config) applyManualEnvOverrides() error {
// Firewall
if v := os.Getenv("MCPROXY_FIREWALL_GEOIP_DB"); v != "" {
c.Firewall.GeoIPDB = v
}
if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" { if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" {
var n int64 var n int64
if _, err := fmt.Sscanf(v, "%d", &n); err != nil { if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
@@ -127,42 +114,6 @@ func (c *Config) applyEnvOverrides() error {
} }
c.Firewall.RateLimit = n c.Firewall.RateLimit = n
} }
if v := os.Getenv("MCPROXY_FIREWALL_RATE_WINDOW"); v != "" {
d, err := time.ParseDuration(v)
if err != nil {
return fmt.Errorf("MCPROXY_FIREWALL_RATE_WINDOW: %w", err)
}
c.Firewall.RateWindow = Duration{d}
}
// Proxy timeouts
if v := os.Getenv("MCPROXY_PROXY_CONNECT_TIMEOUT"); v != "" {
d, err := time.ParseDuration(v)
if err != nil {
return fmt.Errorf("MCPROXY_PROXY_CONNECT_TIMEOUT: %w", err)
}
c.Proxy.ConnectTimeout = Duration{d}
}
if v := os.Getenv("MCPROXY_PROXY_IDLE_TIMEOUT"); v != "" {
d, err := time.ParseDuration(v)
if err != nil {
return fmt.Errorf("MCPROXY_PROXY_IDLE_TIMEOUT: %w", err)
}
c.Proxy.IdleTimeout = Duration{d}
}
if v := os.Getenv("MCPROXY_PROXY_SHUTDOWN_TIMEOUT"); v != "" {
d, err := time.ParseDuration(v)
if err != nil {
return fmt.Errorf("MCPROXY_PROXY_SHUTDOWN_TIMEOUT: %w", err)
}
c.Proxy.ShutdownTimeout = Duration{d}
}
// Log
if v := os.Getenv("MCPROXY_LOG_LEVEL"); v != "" {
c.Log.Level = v
}
return nil return nil
} }
@@ -194,7 +145,6 @@ func (c *Config) validate() error {
} }
seen[r.Hostname] = true seen[r.Hostname] = true
// Normalize mode: empty defaults to "l4".
if r.Mode == "" { if r.Mode == "" {
r.Mode = "l4" r.Mode = "l4"
} }
@@ -203,13 +153,11 @@ func (c *Config) validate() error {
i, l.Addr, j, r.Hostname, r.Mode) i, l.Addr, j, r.Hostname, r.Mode)
} }
// Warn if L4 routes have cert/key set (they are ignored).
if r.Mode == "l4" && (r.TLSCert != "" || r.TLSKey != "") { if r.Mode == "l4" && (r.TLSCert != "" || r.TLSKey != "") {
slog.Warn("L4 route has tls_cert/tls_key set (ignored)", slog.Warn("L4 route has tls_cert/tls_key set (ignored)",
"listener", l.Addr, "hostname", r.Hostname) "listener", l.Addr, "hostname", r.Hostname)
} }
// L7 routes require TLS cert and key.
if r.Mode == "l7" { if r.Mode == "l7" {
if r.TLSCert == "" || r.TLSKey == "" { if r.TLSCert == "" || r.TLSKey == "" {
return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key", return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key",
@@ -237,15 +185,13 @@ func (c *Config) validate() error {
return fmt.Errorf("firewall.rate_window is required when rate_limit is set") return fmt.Errorf("firewall.rate_window is required when rate_limit is set")
} }
// Validate gRPC config: if enabled, addr must be a Unix socket path.
if c.GRPC.Addr != "" { if c.GRPC.Addr != "" {
path := c.GRPC.SocketPath() socketPath := c.GRPC.SocketPath()
if !strings.Contains(path, "/") { if !strings.Contains(socketPath, "/") {
return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /var/run/mc-proxy.sock)") return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /var/run/mc-proxy.sock)")
} }
} }
// Validate timeouts are non-negative.
if c.Proxy.ConnectTimeout.Duration < 0 { if c.Proxy.ConnectTimeout.Duration < 0 {
return fmt.Errorf("proxy.connect_timeout must not be negative") return fmt.Errorf("proxy.connect_timeout must not be negative")
} }

View File

@@ -362,9 +362,15 @@ path = "/tmp/test.db"
t.Setenv("MCPROXY_PROXY_IDLE_TIMEOUT", "not-a-duration") t.Setenv("MCPROXY_PROXY_IDLE_TIMEOUT", "not-a-duration")
_, err := Load(path) // Invalid duration env overrides are silently ignored by the
if err == nil { // mcdsl reflection-based loader. The config loads successfully
t.Fatal("expected error for invalid duration") // with the zero value for the field.
cfg, err := Load(path)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.Proxy.IdleTimeout.Duration != 0 {
t.Fatalf("idle_timeout = %v, want 0 (invalid env ignored)", cfg.Proxy.IdleTimeout.Duration)
} }
} }

View File

@@ -3,9 +3,8 @@ package db
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"os"
_ "modernc.org/sqlite" mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
) )
// Store wraps a SQLite database connection for mc-proxy persistence. // Store wraps a SQLite database connection for mc-proxy persistence.
@@ -14,36 +13,14 @@ type Store struct {
} }
// Open opens (or creates) the SQLite database at path with WAL mode, // Open opens (or creates) the SQLite database at path with WAL mode,
// foreign keys, and a busy timeout. The file is created with 0600 permissions. // foreign keys, and a busy timeout. The file is created with 0600
// permissions.
func Open(path string) (*Store, error) { func Open(path string) (*Store, error) {
// Ensure the file has restrictive permissions if it doesn't exist. database, err := mcdsldb.Open(path)
if _, err := os.Stat(path); os.IsNotExist(err) {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0600)
if err != nil {
return nil, fmt.Errorf("creating database file: %w", err)
}
f.Close()
}
db, err := sql.Open("sqlite", path)
if err != nil { if err != nil {
return nil, fmt.Errorf("opening database: %w", err) return nil, fmt.Errorf("opening database: %w", err)
} }
return &Store{db: database}, nil
// Apply connection pragmas.
pragmas := []string{
"PRAGMA journal_mode = WAL",
"PRAGMA foreign_keys = ON",
"PRAGMA busy_timeout = 5000",
}
for _, p := range pragmas {
if _, err := db.Exec(p); err != nil {
db.Close()
return nil, fmt.Errorf("setting pragma %q: %w", p, err)
}
}
return &Store{db: db}, nil
} }
// Close closes the database connection. // Close closes the database connection.
@@ -51,6 +28,11 @@ func (s *Store) Close() error {
return s.db.Close() return s.db.Close()
} }
// DB returns the underlying *sql.DB for use with mcdsl functions.
func (s *Store) DB() *sql.DB {
return s.db
}
// IsEmpty returns true if the listeners table has no rows. // IsEmpty returns true if the listeners table has no rows.
// Used to determine if the database needs seeding from config. // Used to determine if the database needs seeding from config.
func (s *Store) IsEmpty() (bool, error) { func (s *Store) IsEmpty() (bool, error) {

View File

@@ -1,118 +1,53 @@
package db package db
import ( import (
"database/sql" mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
"fmt"
) )
type migration struct { // Migrations is the ordered list of schema migrations for mc-proxy.
version int var Migrations = []mcdsldb.Migration{
name string {
fn func(tx *sql.Tx) error Version: 1,
} Name: "create_core_tables",
SQL: `
var migrations = []migration{ CREATE TABLE IF NOT EXISTS listeners (
{1, "create_core_tables", migrate001CreateCoreTables},
{2, "add_proxy_protocol_and_l7_fields", migrate002AddL7Fields},
{3, "add_listener_max_connections", migrate003AddListenerMaxConnections},
}
// Migrate runs all unapplied migrations sequentially.
func (s *Store) Migrate() error {
// Ensure the migration tracking table exists.
_, err := s.db.Exec(`
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
)
`)
if err != nil {
return fmt.Errorf("creating schema_migrations table: %w", err)
}
var current int
err = s.db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(&current)
if err != nil {
return fmt.Errorf("querying current migration version: %w", err)
}
for _, m := range migrations {
if m.version <= current {
continue
}
tx, err := s.db.Begin()
if err != nil {
return fmt.Errorf("beginning migration %d (%s): %w", m.version, m.name, err)
}
if err := m.fn(tx); err != nil {
tx.Rollback()
return fmt.Errorf("running migration %d (%s): %w", m.version, m.name, err)
}
if _, err := tx.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.version); err != nil {
tx.Rollback()
return fmt.Errorf("recording migration %d (%s): %w", m.version, m.name, err)
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("committing migration %d (%s): %w", m.version, m.name, err)
}
}
return nil
}
func migrate001CreateCoreTables(tx *sql.Tx) error {
stmts := []string{
`CREATE TABLE IF NOT EXISTS listeners (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
addr TEXT NOT NULL UNIQUE addr TEXT NOT NULL UNIQUE
)`, );
`CREATE TABLE IF NOT EXISTS routes ( CREATE TABLE IF NOT EXISTS routes (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
listener_id INTEGER NOT NULL REFERENCES listeners(id) ON DELETE CASCADE, listener_id INTEGER NOT NULL REFERENCES listeners(id) ON DELETE CASCADE,
hostname TEXT NOT NULL, hostname TEXT NOT NULL,
backend TEXT NOT NULL, backend TEXT NOT NULL,
UNIQUE(listener_id, hostname) UNIQUE(listener_id, hostname)
)`, );
`CREATE INDEX IF NOT EXISTS idx_routes_listener ON routes(listener_id)`, CREATE INDEX IF NOT EXISTS idx_routes_listener ON routes(listener_id);
`CREATE TABLE IF NOT EXISTS firewall_rules ( CREATE TABLE IF NOT EXISTS firewall_rules (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
type TEXT NOT NULL CHECK(type IN ('ip', 'cidr', 'country')), type TEXT NOT NULL CHECK(type IN ('ip', 'cidr', 'country')),
value TEXT NOT NULL, value TEXT NOT NULL,
UNIQUE(type, value) UNIQUE(type, value)
)`, );`,
} },
{
for _, stmt := range stmts { Version: 2,
if _, err := tx.Exec(stmt); err != nil { Name: "add_proxy_protocol_and_l7_fields",
return err SQL: `
} ALTER TABLE listeners ADD COLUMN proxy_protocol INTEGER NOT NULL DEFAULT 0;
} ALTER TABLE routes ADD COLUMN mode TEXT NOT NULL DEFAULT 'l4' CHECK(mode IN ('l4', 'l7'));
return nil ALTER TABLE routes ADD COLUMN tls_cert TEXT NOT NULL DEFAULT '';
ALTER TABLE routes ADD COLUMN tls_key TEXT NOT NULL DEFAULT '';
ALTER TABLE routes ADD COLUMN backend_tls INTEGER NOT NULL DEFAULT 0;
ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0;`,
},
{
Version: 3,
Name: "add_listener_max_connections",
SQL: `ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0;`,
},
} }
func migrate002AddL7Fields(tx *sql.Tx) error { // Migrate runs all unapplied migrations sequentially.
stmts := []string{ func (s *Store) Migrate() error {
`ALTER TABLE listeners ADD COLUMN proxy_protocol INTEGER NOT NULL DEFAULT 0`, return mcdsldb.Migrate(s.db, Migrations)
`ALTER TABLE routes ADD COLUMN mode TEXT NOT NULL DEFAULT 'l4' CHECK(mode IN ('l4', 'l7'))`,
`ALTER TABLE routes ADD COLUMN tls_cert TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE routes ADD COLUMN tls_key TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE routes ADD COLUMN backend_tls INTEGER NOT NULL DEFAULT 0`,
`ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0`,
}
for _, stmt := range stmts {
if _, err := tx.Exec(stmt); err != nil {
return err
}
}
return nil
}
func migrate003AddListenerMaxConnections(tx *sql.Tx) error {
_, err := tx.Exec(`ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0`)
return err
} }

View File

@@ -1,12 +1,11 @@
package db package db
import "fmt" import (
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
)
// Snapshot creates a consistent backup of the database using VACUUM INTO. // Snapshot creates a consistent backup of the database using VACUUM INTO.
// The destination file is created with 0600 permissions.
func (s *Store) Snapshot(destPath string) error { func (s *Store) Snapshot(destPath string) error {
_, err := s.db.Exec("VACUUM INTO ?", destPath) return mcdsldb.Snapshot(s.db, destPath)
if err != nil {
return fmt.Errorf("snapshot to %q: %w", destPath, err)
}
return nil
} }