diff --git a/cmd/mc-proxy/root.go b/cmd/mc-proxy/root.go index 649174e..05d9369 100644 --- a/cmd/mc-proxy/root.go +++ b/cmd/mc-proxy/root.go @@ -12,6 +12,7 @@ func rootCmd() *cobra.Command { } cmd.AddCommand(serverCmd()) + cmd.AddCommand(snapshotCmd()) return cmd } diff --git a/cmd/mc-proxy/server.go b/cmd/mc-proxy/server.go index 015bb39..f96c624 100644 --- a/cmd/mc-proxy/server.go +++ b/cmd/mc-proxy/server.go @@ -2,14 +2,18 @@ package main import ( "context" + "fmt" "log/slog" "os" "os/signal" + "strings" "syscall" "github.com/spf13/cobra" "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/db" + "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver" "git.wntrmute.dev/kyle/mc-proxy/internal/server" ) @@ -30,17 +34,53 @@ func serverCmd() *cobra.Command { Level: parseLogLevel(cfg.Log.Level), })) - srv, err := server.New(cfg, logger, version) + // Open and migrate the database. + store, err := db.Open(cfg.Database.Path) + if err != nil { + return fmt.Errorf("opening database: %w", err) + } + defer store.Close() + + if err := store.Migrate(); err != nil { + return fmt.Errorf("running migrations: %w", err) + } + + // Seed from config on first run, or load from DB. + empty, err := store.IsEmpty() + if err != nil { + return fmt.Errorf("checking database: %w", err) + } + + if empty { + if len(cfg.Listeners) == 0 { + return fmt.Errorf("database is empty and no listeners defined in config for seeding") + } + logger.Info("seeding database from config") + if err := store.Seed(cfg.Listeners, cfg.Firewall); err != nil { + return fmt.Errorf("seeding database: %w", err) + } + } + + // Load listeners and routes from DB. + listenerData, err := loadListenersFromDB(store) if err != nil { return err } + // Load firewall rules from DB. + fw, err := loadFirewallFromDB(store, cfg.Firewall.GeoIPDB) + if err != nil { + return err + } + + srv := server.New(cfg, fw, listenerData, logger, version) + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer cancel() // Start gRPC admin API if configured. if cfg.GRPC.Addr != "" { - grpcSrv, ln, err := grpcserver.New(cfg.GRPC, srv, logger) + grpcSrv, ln, err := grpcserver.New(cfg.GRPC, srv, store, logger) if err != nil { return err } @@ -75,6 +115,56 @@ func serverCmd() *cobra.Command { return cmd } +func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) { + dbListeners, err := store.ListListeners() + if err != nil { + return nil, fmt.Errorf("loading listeners: %w", err) + } + + var result []server.ListenerData + for _, l := range dbListeners { + dbRoutes, err := store.ListRoutes(l.ID) + if err != nil { + return nil, fmt.Errorf("loading routes for listener %q: %w", l.Addr, err) + } + routes := make(map[string]string, len(dbRoutes)) + for _, r := range dbRoutes { + routes[strings.ToLower(r.Hostname)] = r.Backend + } + result = append(result, server.ListenerData{ + ID: l.ID, + Addr: l.Addr, + Routes: routes, + }) + } + return result, nil +} + +func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall, error) { + rules, err := store.ListFirewallRules() + if err != nil { + return nil, fmt.Errorf("loading firewall rules: %w", err) + } + + var ips, cidrs, countries []string + for _, r := range rules { + switch r.Type { + case "ip": + ips = append(ips, r.Value) + case "cidr": + cidrs = append(cidrs, r.Value) + case "country": + countries = append(countries, r.Value) + } + } + + fw, err := firewall.New(geoIPPath, ips, cidrs, countries) + if err != nil { + return nil, fmt.Errorf("initializing firewall: %w", err) + } + return fw, nil +} + func parseLogLevel(s string) slog.Level { switch s { case "debug": diff --git a/cmd/mc-proxy/snapshot.go b/cmd/mc-proxy/snapshot.go new file mode 100644 index 0000000..4c36570 --- /dev/null +++ b/cmd/mc-proxy/snapshot.go @@ -0,0 +1,54 @@ +package main + +import ( + "fmt" + "path/filepath" + "time" + + "github.com/spf13/cobra" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/db" +) + +func snapshotCmd() *cobra.Command { + var ( + configPath string + outputPath string + ) + + cmd := &cobra.Command{ + Use: "snapshot", + Short: "Create a database backup", + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := config.Load(configPath) + if err != nil { + return err + } + + store, err := db.Open(cfg.Database.Path) + if err != nil { + return fmt.Errorf("opening database: %w", err) + } + defer store.Close() + + if outputPath == "" { + dir := filepath.Dir(cfg.Database.Path) + ts := time.Now().UTC().Format("20060102T150405Z") + outputPath = filepath.Join(dir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts)) + } + + if err := store.Snapshot(outputPath); err != nil { + return err + } + + fmt.Printf("snapshot written to %s\n", outputPath) + return nil + }, + } + + cmd.Flags().StringVarP(&configPath, "config", "c", "mc-proxy.toml", "path to configuration file") + cmd.Flags().StringVarP(&outputPath, "output", "o", "", "output path (default: backups/mc-proxy-.db)") + + return cmd +} diff --git a/go.mod b/go.mod index f133b1b..a6b55ce 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.wntrmute.dev/kyle/mc-proxy -go 1.24.0 +go 1.25.0 require ( github.com/oschwald/maxminddb-golang v1.13.1 @@ -8,13 +8,22 @@ require ( github.com/spf13/cobra v1.10.2 google.golang.org/grpc v1.79.2 google.golang.org/protobuf v1.36.11 + modernc.org/sqlite v1.46.2 ) require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/spf13/pflag v1.0.9 // indirect golang.org/x/net v0.48.0 // indirect - golang.org/x/sys v0.39.0 // indirect + golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.32.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect + modernc.org/libc v1.70.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 575f3f8..ee0ca60 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,8 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -11,16 +13,26 @@ github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= @@ -41,12 +53,19 @@ go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5w go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= @@ -58,3 +77,31 @@ google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.32.0 h1:hjG66bI/kqIPX1b2yT6fr/jt+QedtP2fqojG2VrFuVw= +modernc.org/ccgo/v4 v4.32.0/go.mod h1:6F08EBCx5uQc38kMGl+0Nm0oWczoo1c7cgpzEry7Uc0= +modernc.org/fileutil v1.4.0 h1:j6ZzNTftVS054gi281TyLjHPp6CPHr2KCxEXjEbD6SM= +modernc.org/fileutil v1.4.0/go.mod h1:EqdKFDxiByqxLk8ozOxObDSfcVOv/54xDs/DUHdvCUU= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.2 h1:ZtDCnhonXSZexk/AYsegNRV1lJGgaNZJuKjJSWKyEqo= +modernc.org/gc/v3 v3.1.2/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= +modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.2 h1:gkXQ6R0+AjxFC/fTDaeIVLbNLNrRoOK7YYVz5BKhTcE= +modernc.org/sqlite v1.46.2/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/internal/config/config.go b/internal/config/config.go index 5ceb98d..fc9d1a1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,12 +10,17 @@ import ( type Config struct { Listeners []Listener `toml:"listeners"` + Database Database `toml:"database"` GRPC GRPC `toml:"grpc"` Firewall Firewall `toml:"firewall"` Proxy Proxy `toml:"proxy"` Log Log `toml:"log"` } +type Database struct { + Path string `toml:"path"` +} + type GRPC struct { Addr string `toml:"addr"` TLSCert string `toml:"tls_cert"` @@ -80,18 +85,15 @@ func Load(path string) (*Config, error) { } func (c *Config) validate() error { - if len(c.Listeners) == 0 { - return fmt.Errorf("at least one listener is required") + if c.Database.Path == "" { + return fmt.Errorf("database.path is required") } + // Validate listeners if provided (used for seeding on first run). for i, l := range c.Listeners { if l.Addr == "" { return fmt.Errorf("listener %d: addr is required", i) } - if len(l.Routes) == 0 { - return fmt.Errorf("listener %d (%s): at least one route is required", i, l.Addr) - } - seen := make(map[string]bool) for j, r := range l.Routes { if r.Hostname == "" { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index eee7dbb..909ce4a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -11,6 +11,9 @@ func TestLoadValid(t *testing.T) { path := filepath.Join(dir, "test.toml") data := ` +[database] +path = "/tmp/test.db" + [[listeners]] addr = ":443" @@ -49,11 +52,37 @@ level = "info" } } -func TestLoadNoListeners(t *testing.T) { +func TestLoadNoDatabasePath(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test.toml") data := ` +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "example.com" + backend = "127.0.0.1:8443" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for missing database path") + } +} + +func TestLoadNoListenersValid(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + // No listeners is valid — DB may already have them. + data := ` +[database] +path = "/tmp/test.db" + [log] level = "info" ` @@ -62,26 +91,8 @@ level = "info" } _, err := Load(path) - if err == nil { - t.Fatal("expected error for missing listeners") - } -} - -func TestLoadNoRoutes(t *testing.T) { - dir := t.TempDir() - path := filepath.Join(dir, "test.toml") - - data := ` -[[listeners]] -addr = ":443" -` - if err := os.WriteFile(path, []byte(data), 0600); err != nil { - t.Fatalf("write config: %v", err) - } - - _, err := Load(path) - if err == nil { - t.Fatal("expected error for missing routes") + if err != nil { + t.Fatalf("unexpected error: %v", err) } } @@ -90,6 +101,9 @@ func TestLoadDuplicateHostnames(t *testing.T) { path := filepath.Join(dir, "test.toml") data := ` +[database] +path = "/tmp/test.db" + [[listeners]] addr = ":443" @@ -116,6 +130,9 @@ func TestLoadGeoIPRequiredWithCountries(t *testing.T) { path := filepath.Join(dir, "test.toml") data := ` +[database] +path = "/tmp/test.db" + [[listeners]] addr = ":443" @@ -141,6 +158,9 @@ func TestLoadMultipleListeners(t *testing.T) { path := filepath.Join(dir, "test.toml") data := ` +[database] +path = "/tmp/test.db" + [[listeners]] addr = ":443" diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..a73ae20 --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,63 @@ +package db + +import ( + "database/sql" + "fmt" + "os" + + _ "modernc.org/sqlite" +) + +// Store wraps a SQLite database connection for mc-proxy persistence. +type Store struct { + db *sql.DB +} + +// Open opens (or creates) the SQLite database at path with WAL mode, +// foreign keys, and a busy timeout. The file is created with 0600 permissions. +func Open(path string) (*Store, error) { + // Ensure the file has restrictive permissions if it doesn't exist. + if _, err := os.Stat(path); os.IsNotExist(err) { + f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, fmt.Errorf("creating database file: %w", err) + } + f.Close() + } + + db, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("opening database: %w", err) + } + + // Apply connection pragmas. + pragmas := []string{ + "PRAGMA journal_mode = WAL", + "PRAGMA foreign_keys = ON", + "PRAGMA busy_timeout = 5000", + } + for _, p := range pragmas { + if _, err := db.Exec(p); err != nil { + db.Close() + return nil, fmt.Errorf("setting pragma %q: %w", p, err) + } + } + + return &Store{db: db}, nil +} + +// Close closes the database connection. +func (s *Store) Close() error { + return s.db.Close() +} + +// IsEmpty returns true if the listeners table has no rows. +// Used to determine if the database needs seeding from config. +func (s *Store) IsEmpty() (bool, error) { + var count int + err := s.db.QueryRow("SELECT COUNT(*) FROM listeners").Scan(&count) + if err != nil { + return false, err + } + return count == 0, nil +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 0000000..3cb9080 --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,331 @@ +package db + +import ( + "path/filepath" + "testing" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" +) + +func openTestDB(t *testing.T) *Store { + t.Helper() + dir := t.TempDir() + store, err := Open(filepath.Join(dir, "test.db")) + if err != nil { + t.Fatalf("open: %v", err) + } + if err := store.Migrate(); err != nil { + t.Fatalf("migrate: %v", err) + } + t.Cleanup(func() { store.Close() }) + return store +} + +func TestMigrate(t *testing.T) { + store := openTestDB(t) + + // Running migrate again should be idempotent. + if err := store.Migrate(); err != nil { + t.Fatalf("second migrate: %v", err) + } +} + +func TestIsEmpty(t *testing.T) { + store := openTestDB(t) + + empty, err := store.IsEmpty() + if err != nil { + t.Fatalf("is empty: %v", err) + } + if !empty { + t.Fatal("expected empty database") + } + + if _, err := store.CreateListener(":443"); err != nil { + t.Fatalf("create listener: %v", err) + } + + empty, err = store.IsEmpty() + if err != nil { + t.Fatalf("is empty: %v", err) + } + if empty { + t.Fatal("expected non-empty database") + } +} + +func TestListenerCRUD(t *testing.T) { + store := openTestDB(t) + + id, err := store.CreateListener(":443") + if err != nil { + t.Fatalf("create: %v", err) + } + if id == 0 { + t.Fatal("expected non-zero ID") + } + + listeners, err := store.ListListeners() + if err != nil { + t.Fatalf("list: %v", err) + } + if len(listeners) != 1 { + t.Fatalf("got %d listeners, want 1", len(listeners)) + } + if listeners[0].Addr != ":443" { + t.Fatalf("got addr %q, want %q", listeners[0].Addr, ":443") + } + + l, err := store.GetListenerByAddr(":443") + if err != nil { + t.Fatalf("get by addr: %v", err) + } + if l.ID != id { + t.Fatalf("got ID %d, want %d", l.ID, id) + } + + if err := store.DeleteListener(id); err != nil { + t.Fatalf("delete: %v", err) + } + + listeners, err = store.ListListeners() + if err != nil { + t.Fatalf("list after delete: %v", err) + } + if len(listeners) != 0 { + t.Fatalf("got %d listeners after delete, want 0", len(listeners)) + } +} + +func TestListenerDuplicateAddr(t *testing.T) { + store := openTestDB(t) + + if _, err := store.CreateListener(":443"); err != nil { + t.Fatalf("first create: %v", err) + } + if _, err := store.CreateListener(":443"); err == nil { + t.Fatal("expected error for duplicate addr") + } +} + +func TestRouteCRUD(t *testing.T) { + store := openTestDB(t) + + listenerID, err := store.CreateListener(":443") + if err != nil { + t.Fatalf("create listener: %v", err) + } + + routeID, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443") + if err != nil { + t.Fatalf("create route: %v", err) + } + if routeID == 0 { + t.Fatal("expected non-zero route ID") + } + + routes, err := store.ListRoutes(listenerID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + if len(routes) != 1 { + t.Fatalf("got %d routes, want 1", len(routes)) + } + if routes[0].Hostname != "example.com" { + t.Fatalf("got hostname %q, want %q", routes[0].Hostname, "example.com") + } + + if err := store.DeleteRoute(listenerID, "example.com"); err != nil { + t.Fatalf("delete route: %v", err) + } + + routes, err = store.ListRoutes(listenerID) + if err != nil { + t.Fatalf("list after delete: %v", err) + } + if len(routes) != 0 { + t.Fatalf("got %d routes after delete, want 0", len(routes)) + } +} + +func TestRouteDuplicateHostname(t *testing.T) { + store := openTestDB(t) + + listenerID, _ := store.CreateListener(":443") + if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443"); err != nil { + t.Fatalf("first create: %v", err) + } + if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443"); err == nil { + t.Fatal("expected error for duplicate hostname on same listener") + } +} + +func TestRouteCascadeDelete(t *testing.T) { + store := openTestDB(t) + + listenerID, _ := store.CreateListener(":443") + store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443") + store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443") + + if err := store.DeleteListener(listenerID); err != nil { + t.Fatalf("delete listener: %v", err) + } + + routes, err := store.ListRoutes(listenerID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + if len(routes) != 0 { + t.Fatalf("got %d routes after cascade delete, want 0", len(routes)) + } +} + +func TestFirewallRuleCRUD(t *testing.T) { + store := openTestDB(t) + + id, err := store.CreateFirewallRule("ip", "192.0.2.1") + if err != nil { + t.Fatalf("create: %v", err) + } + if id == 0 { + t.Fatal("expected non-zero ID") + } + + if _, err := store.CreateFirewallRule("cidr", "198.51.100.0/24"); err != nil { + t.Fatalf("create cidr: %v", err) + } + if _, err := store.CreateFirewallRule("country", "CN"); err != nil { + t.Fatalf("create country: %v", err) + } + + rules, err := store.ListFirewallRules() + if err != nil { + t.Fatalf("list: %v", err) + } + if len(rules) != 3 { + t.Fatalf("got %d rules, want 3", len(rules)) + } + + if err := store.DeleteFirewallRule("ip", "192.0.2.1"); err != nil { + t.Fatalf("delete: %v", err) + } + + rules, err = store.ListFirewallRules() + if err != nil { + t.Fatalf("list after delete: %v", err) + } + if len(rules) != 2 { + t.Fatalf("got %d rules after delete, want 2", len(rules)) + } +} + +func TestFirewallRuleDuplicate(t *testing.T) { + store := openTestDB(t) + + if _, err := store.CreateFirewallRule("ip", "192.0.2.1"); err != nil { + t.Fatalf("first create: %v", err) + } + if _, err := store.CreateFirewallRule("ip", "192.0.2.1"); err == nil { + t.Fatal("expected error for duplicate rule") + } +} + +func TestSeed(t *testing.T) { + store := openTestDB(t) + + listeners := []config.Listener{ + { + Addr: ":443", + Routes: []config.Route{ + {Hostname: "a.example.com", Backend: "127.0.0.1:8443"}, + {Hostname: "b.example.com", Backend: "127.0.0.1:9443"}, + }, + }, + { + Addr: ":8443", + Routes: []config.Route{ + {Hostname: "c.example.com", Backend: "127.0.0.1:18443"}, + }, + }, + } + + fw := config.Firewall{ + BlockedIPs: []string{"192.0.2.1"}, + BlockedCIDRs: []string{"198.51.100.0/24"}, + BlockedCountries: []string{"cn", "KP"}, + } + + if err := store.Seed(listeners, fw); err != nil { + t.Fatalf("seed: %v", err) + } + + dbListeners, err := store.ListListeners() + if err != nil { + t.Fatalf("list listeners: %v", err) + } + if len(dbListeners) != 2 { + t.Fatalf("got %d listeners, want 2", len(dbListeners)) + } + + routes, err := store.ListRoutes(dbListeners[0].ID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + if len(routes) != 2 { + t.Fatalf("got %d routes for listener 0, want 2", len(routes)) + } + + rules, err := store.ListFirewallRules() + if err != nil { + t.Fatalf("list firewall rules: %v", err) + } + if len(rules) != 4 { + t.Fatalf("got %d firewall rules, want 4", len(rules)) + } +} + +func TestSnapshot(t *testing.T) { + store := openTestDB(t) + + store.CreateListener(":443") + + dest := filepath.Join(t.TempDir(), "backup.db") + if err := store.Snapshot(dest); err != nil { + t.Fatalf("snapshot: %v", err) + } + + // Open the snapshot and verify. + backup, err := Open(dest) + if err != nil { + t.Fatalf("open backup: %v", err) + } + defer backup.Close() + + if err := backup.Migrate(); err != nil { + t.Fatalf("migrate backup: %v", err) + } + + listeners, err := backup.ListListeners() + if err != nil { + t.Fatalf("list from backup: %v", err) + } + if len(listeners) != 1 { + t.Fatalf("backup has %d listeners, want 1", len(listeners)) + } +} + +func TestDeleteNonexistent(t *testing.T) { + store := openTestDB(t) + + if err := store.DeleteListener(999); err == nil { + t.Fatal("expected error deleting nonexistent listener") + } + + if err := store.DeleteRoute(999, "example.com"); err == nil { + t.Fatal("expected error deleting nonexistent route") + } + + if err := store.DeleteFirewallRule("ip", "1.2.3.4"); err == nil { + t.Fatal("expected error deleting nonexistent firewall rule") + } +} diff --git a/internal/db/firewall.go b/internal/db/firewall.go new file mode 100644 index 0000000..73156a4 --- /dev/null +++ b/internal/db/firewall.go @@ -0,0 +1,57 @@ +package db + +import "fmt" + +// FirewallRule is a database firewall rule record. +type FirewallRule struct { + ID int64 + Type string // "ip", "cidr", "country" + Value string +} + +// ListFirewallRules returns all firewall rules. +func (s *Store) ListFirewallRules() ([]FirewallRule, error) { + rows, err := s.db.Query("SELECT id, type, value FROM firewall_rules ORDER BY type, value") + if err != nil { + return nil, fmt.Errorf("querying firewall rules: %w", err) + } + defer rows.Close() + + var rules []FirewallRule + for rows.Next() { + var r FirewallRule + if err := rows.Scan(&r.ID, &r.Type, &r.Value); err != nil { + return nil, fmt.Errorf("scanning firewall rule: %w", err) + } + rules = append(rules, r) + } + return rules, rows.Err() +} + +// CreateFirewallRule inserts a firewall rule and returns its ID. +func (s *Store) CreateFirewallRule(ruleType, value string) (int64, error) { + result, err := s.db.Exec( + "INSERT INTO firewall_rules (type, value) VALUES (?, ?)", + ruleType, value, + ) + if err != nil { + return 0, fmt.Errorf("inserting firewall rule: %w", err) + } + return result.LastInsertId() +} + +// DeleteFirewallRule deletes a firewall rule by type and value. +func (s *Store) DeleteFirewallRule(ruleType, value string) error { + result, err := s.db.Exec( + "DELETE FROM firewall_rules WHERE type = ? AND value = ?", + ruleType, value, + ) + if err != nil { + return fmt.Errorf("deleting firewall rule: %w", err) + } + n, _ := result.RowsAffected() + if n == 0 { + return fmt.Errorf("firewall rule (%s, %s) not found", ruleType, value) + } + return nil +} diff --git a/internal/db/listeners.go b/internal/db/listeners.go new file mode 100644 index 0000000..0daeae7 --- /dev/null +++ b/internal/db/listeners.go @@ -0,0 +1,61 @@ +package db + +import "fmt" + +// Listener is a database listener record. +type Listener struct { + ID int64 + Addr string +} + +// ListListeners returns all listeners. +func (s *Store) ListListeners() ([]Listener, error) { + rows, err := s.db.Query("SELECT id, addr FROM listeners ORDER BY id") + if err != nil { + return nil, fmt.Errorf("querying listeners: %w", err) + } + defer rows.Close() + + var listeners []Listener + for rows.Next() { + var l Listener + if err := rows.Scan(&l.ID, &l.Addr); err != nil { + return nil, fmt.Errorf("scanning listener: %w", err) + } + listeners = append(listeners, l) + } + return listeners, rows.Err() +} + +// CreateListener inserts a listener and returns its ID. +func (s *Store) CreateListener(addr string) (int64, error) { + result, err := s.db.Exec("INSERT INTO listeners (addr) VALUES (?)", addr) + if err != nil { + return 0, fmt.Errorf("inserting listener: %w", err) + } + return result.LastInsertId() +} + +// DeleteListener deletes a listener by ID. Routes are cascade-deleted. +func (s *Store) DeleteListener(id int64) error { + result, err := s.db.Exec("DELETE FROM listeners WHERE id = ?", id) + if err != nil { + return fmt.Errorf("deleting listener: %w", err) + } + n, _ := result.RowsAffected() + if n == 0 { + return fmt.Errorf("listener %d not found", id) + } + return nil +} + +// GetListenerByAddr returns a listener by its address. +func (s *Store) GetListenerByAddr(addr string) (Listener, error) { + var l Listener + err := s.db.QueryRow("SELECT id, addr FROM listeners WHERE addr = ?", addr). + Scan(&l.ID, &l.Addr) + if err != nil { + return Listener{}, fmt.Errorf("querying listener by addr %q: %w", addr, err) + } + return l, nil +} diff --git a/internal/db/migrations.go b/internal/db/migrations.go new file mode 100644 index 0000000..caee4e4 --- /dev/null +++ b/internal/db/migrations.go @@ -0,0 +1,93 @@ +package db + +import ( + "database/sql" + "fmt" +) + +type migration struct { + version int + name string + fn func(tx *sql.Tx) error +} + +var migrations = []migration{ + {1, "create_core_tables", migrate001CreateCoreTables}, +} + +// Migrate runs all unapplied migrations sequentially. +func (s *Store) Migrate() error { + // Ensure the migration tracking table exists. + _, err := s.db.Exec(` + CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')) + ) + `) + if err != nil { + return fmt.Errorf("creating schema_migrations table: %w", err) + } + + var current int + err = s.db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM schema_migrations").Scan(¤t) + if err != nil { + return fmt.Errorf("querying current migration version: %w", err) + } + + for _, m := range migrations { + if m.version <= current { + continue + } + + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("beginning migration %d (%s): %w", m.version, m.name, err) + } + + if err := m.fn(tx); err != nil { + tx.Rollback() + return fmt.Errorf("running migration %d (%s): %w", m.version, m.name, err) + } + + if _, err := tx.Exec("INSERT INTO schema_migrations (version) VALUES (?)", m.version); err != nil { + tx.Rollback() + return fmt.Errorf("recording migration %d (%s): %w", m.version, m.name, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing migration %d (%s): %w", m.version, m.name, err) + } + } + + return nil +} + +func migrate001CreateCoreTables(tx *sql.Tx) error { + stmts := []string{ + `CREATE TABLE IF NOT EXISTS listeners ( + id INTEGER PRIMARY KEY, + addr TEXT NOT NULL UNIQUE + )`, + `CREATE TABLE IF NOT EXISTS routes ( + id INTEGER PRIMARY KEY, + listener_id INTEGER NOT NULL REFERENCES listeners(id) ON DELETE CASCADE, + hostname TEXT NOT NULL, + backend TEXT NOT NULL, + UNIQUE(listener_id, hostname) + )`, + `CREATE INDEX IF NOT EXISTS idx_routes_listener ON routes(listener_id)`, + `CREATE TABLE IF NOT EXISTS firewall_rules ( + id INTEGER PRIMARY KEY, + type TEXT NOT NULL CHECK(type IN ('ip', 'cidr', 'country')), + value TEXT NOT NULL, + UNIQUE(type, value) + )`, + } + + for _, stmt := range stmts { + if _, err := tx.Exec(stmt); err != nil { + return err + } + } + return nil +} diff --git a/internal/db/routes.go b/internal/db/routes.go new file mode 100644 index 0000000..304d527 --- /dev/null +++ b/internal/db/routes.go @@ -0,0 +1,61 @@ +package db + +import "fmt" + +// Route is a database route record. +type Route struct { + ID int64 + ListenerID int64 + Hostname string + Backend string +} + +// ListRoutes returns all routes for a listener. +func (s *Store) ListRoutes(listenerID int64) ([]Route, error) { + rows, err := s.db.Query( + "SELECT id, listener_id, hostname, backend FROM routes WHERE listener_id = ? ORDER BY hostname", + listenerID, + ) + if err != nil { + return nil, fmt.Errorf("querying routes: %w", err) + } + defer rows.Close() + + var routes []Route + for rows.Next() { + var r Route + if err := rows.Scan(&r.ID, &r.ListenerID, &r.Hostname, &r.Backend); err != nil { + return nil, fmt.Errorf("scanning route: %w", err) + } + routes = append(routes, r) + } + return routes, rows.Err() +} + +// CreateRoute inserts a route and returns its ID. +func (s *Store) CreateRoute(listenerID int64, hostname, backend string) (int64, error) { + result, err := s.db.Exec( + "INSERT INTO routes (listener_id, hostname, backend) VALUES (?, ?, ?)", + listenerID, hostname, backend, + ) + if err != nil { + return 0, fmt.Errorf("inserting route: %w", err) + } + return result.LastInsertId() +} + +// DeleteRoute deletes a route by listener ID and hostname. +func (s *Store) DeleteRoute(listenerID int64, hostname string) error { + result, err := s.db.Exec( + "DELETE FROM routes WHERE listener_id = ? AND hostname = ?", + listenerID, hostname, + ) + if err != nil { + return fmt.Errorf("deleting route: %w", err) + } + n, _ := result.RowsAffected() + if n == 0 { + return fmt.Errorf("route %q not found on listener %d", hostname, listenerID) + } + return nil +} diff --git a/internal/db/seed.go b/internal/db/seed.go new file mode 100644 index 0000000..848fb58 --- /dev/null +++ b/internal/db/seed.go @@ -0,0 +1,56 @@ +package db + +import ( + "fmt" + "strings" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" +) + +// Seed populates the database from TOML config data. Only called when the +// database is empty (first run). +func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error { + tx, err := s.db.Begin() + if err != nil { + return fmt.Errorf("beginning seed transaction: %w", err) + } + defer tx.Rollback() + + for _, l := range listeners { + result, err := tx.Exec("INSERT INTO listeners (addr) VALUES (?)", l.Addr) + if err != nil { + return fmt.Errorf("seeding listener %q: %w", l.Addr, err) + } + listenerID, _ := result.LastInsertId() + + for _, r := range l.Routes { + _, err := tx.Exec( + "INSERT INTO routes (listener_id, hostname, backend) VALUES (?, ?, ?)", + listenerID, strings.ToLower(r.Hostname), r.Backend, + ) + if err != nil { + return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err) + } + } + } + + for _, ip := range fw.BlockedIPs { + if _, err := tx.Exec("INSERT INTO firewall_rules (type, value) VALUES ('ip', ?)", ip); err != nil { + return fmt.Errorf("seeding blocked IP %q: %w", ip, err) + } + } + + for _, cidr := range fw.BlockedCIDRs { + if _, err := tx.Exec("INSERT INTO firewall_rules (type, value) VALUES ('cidr', ?)", cidr); err != nil { + return fmt.Errorf("seeding blocked CIDR %q: %w", cidr, err) + } + } + + for _, code := range fw.BlockedCountries { + if _, err := tx.Exec("INSERT INTO firewall_rules (type, value) VALUES ('country', ?)", strings.ToUpper(code)); err != nil { + return fmt.Errorf("seeding blocked country %q: %w", code, err) + } + } + + return tx.Commit() +} diff --git a/internal/db/snapshot.go b/internal/db/snapshot.go new file mode 100644 index 0000000..1c51c9a --- /dev/null +++ b/internal/db/snapshot.go @@ -0,0 +1,12 @@ +package db + +import "fmt" + +// Snapshot creates a consistent backup of the database using VACUUM INTO. +func (s *Store) Snapshot(destPath string) error { + _, err := s.db.Exec("VACUUM INTO ?", destPath) + if err != nil { + return fmt.Errorf("snapshot to %q: %w", destPath, err) + } + return nil +} diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 5c19687..13669a4 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -8,7 +8,6 @@ import ( "github.com/oschwald/maxminddb-golang" - "git.wntrmute.dev/kyle/mc-proxy/internal/config" ) type geoIPRecord struct { @@ -27,15 +26,15 @@ type Firewall struct { mu sync.RWMutex // protects all mutable state } -// New creates a Firewall from the given configuration. -func New(cfg config.Firewall) (*Firewall, error) { +// New creates a Firewall from raw rule lists and an optional GeoIP database path. +func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) { f := &Firewall{ blockedIPs: make(map[netip.Addr]struct{}), blockedCountries: make(map[string]struct{}), - geoDBPath: cfg.GeoIPDB, + geoDBPath: geoIPPath, } - for _, ip := range cfg.BlockedIPs { + for _, ip := range ips { addr, err := netip.ParseAddr(ip) if err != nil { return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err) @@ -43,7 +42,7 @@ func New(cfg config.Firewall) (*Firewall, error) { f.blockedIPs[addr] = struct{}{} } - for _, cidr := range cfg.BlockedCIDRs { + for _, cidr := range cidrs { prefix, err := netip.ParsePrefix(cidr) if err != nil { return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err) @@ -51,12 +50,12 @@ func New(cfg config.Firewall) (*Firewall, error) { f.blockedCIDRs = append(f.blockedCIDRs, prefix) } - for _, code := range cfg.BlockedCountries { + for _, code := range countries { f.blockedCountries[strings.ToUpper(code)] = struct{}{} } - if len(f.blockedCountries) > 0 { - if err := f.loadGeoDB(cfg.GeoIPDB); err != nil { + if len(f.blockedCountries) > 0 && geoIPPath != "" { + if err := f.loadGeoDB(geoIPPath); err != nil { return nil, fmt.Errorf("loading GeoIP database: %w", err) } } diff --git a/internal/firewall/firewall_test.go b/internal/firewall/firewall_test.go index f442c28..6b9655e 100644 --- a/internal/firewall/firewall_test.go +++ b/internal/firewall/firewall_test.go @@ -3,12 +3,10 @@ package firewall import ( "net/netip" "testing" - - "git.wntrmute.dev/kyle/mc-proxy/internal/config" ) func TestEmptyFirewall(t *testing.T) { - fw, err := New(config.Firewall{}) + fw, err := New("", nil, nil, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -24,9 +22,7 @@ func TestEmptyFirewall(t *testing.T) { } func TestIPBlocking(t *testing.T) { - fw, err := New(config.Firewall{ - BlockedIPs: []string{"192.0.2.1", "2001:db8::dead"}, - }) + fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -51,9 +47,7 @@ func TestIPBlocking(t *testing.T) { } func TestCIDRBlocking(t *testing.T) { - fw, err := New(config.Firewall{ - BlockedCIDRs: []string{"198.51.100.0/24", "2001:db8::/32"}, - }) + fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -79,15 +73,12 @@ func TestCIDRBlocking(t *testing.T) { } func TestIPv4MappedIPv6(t *testing.T) { - fw, err := New(config.Firewall{ - BlockedIPs: []string{"192.0.2.1"}, - }) + fw, err := New("", []string{"192.0.2.1"}, nil, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } defer fw.Close() - // IPv4-mapped IPv6 representation of 192.0.2.1. addr := netip.MustParseAddr("::ffff:192.0.2.1") if !fw.Blocked(addr) { t.Fatal("expected IPv4-mapped IPv6 address to be blocked") @@ -95,28 +86,21 @@ func TestIPv4MappedIPv6(t *testing.T) { } func TestInvalidIP(t *testing.T) { - _, err := New(config.Firewall{ - BlockedIPs: []string{"not-an-ip"}, - }) + _, err := New("", []string{"not-an-ip"}, nil, nil) if err == nil { t.Fatal("expected error for invalid IP") } } func TestInvalidCIDR(t *testing.T) { - _, err := New(config.Firewall{ - BlockedCIDRs: []string{"not-a-cidr"}, - }) + _, err := New("", nil, []string{"not-a-cidr"}, nil) if err == nil { t.Fatal("expected error for invalid CIDR") } } func TestCombinedRules(t *testing.T) { - fw, err := New(config.Firewall{ - BlockedIPs: []string{"10.0.0.1"}, - BlockedCIDRs: []string{"192.168.0.0/16"}, - }) + fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -126,10 +110,10 @@ func TestCombinedRules(t *testing.T) { addr string blocked bool }{ - {"10.0.0.1", true}, // IP match - {"10.0.0.2", false}, // no match - {"192.168.1.1", true}, // CIDR match - {"172.16.0.1", false}, // no match + {"10.0.0.1", true}, + {"10.0.0.2", false}, + {"192.168.1.1", true}, + {"172.16.0.1", false}, } for _, tt := range tests { @@ -139,3 +123,30 @@ func TestCombinedRules(t *testing.T) { } } } + +func TestRuntimeMutation(t *testing.T) { + fw, err := New("", nil, nil, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + addr := netip.MustParseAddr("10.0.0.1") + if fw.Blocked(addr) { + t.Fatal("should not be blocked initially") + } + + if err := fw.AddIP("10.0.0.1"); err != nil { + t.Fatalf("add IP: %v", err) + } + if !fw.Blocked(addr) { + t.Fatal("should be blocked after AddIP") + } + + if err := fw.RemoveIP("10.0.0.1"); err != nil { + t.Fatalf("remove IP: %v", err) + } + if fw.Blocked(addr) { + t.Fatal("should not be blocked after RemoveIP") + } +} diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go index e40e568..8c5350d 100644 --- a/internal/grpcserver/grpcserver.go +++ b/internal/grpcserver/grpcserver.go @@ -8,6 +8,7 @@ import ( "log/slog" "net" "os" + "strings" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -17,6 +18,7 @@ import ( pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1" "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/kyle/mc-proxy/internal/server" ) @@ -24,11 +26,12 @@ import ( type AdminServer struct { pb.UnimplementedProxyAdminServer srv *server.Server + store *db.Store logger *slog.Logger } // New creates a gRPC server with TLS and optional mTLS. -func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) { +func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) { cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) if err != nil { return nil, nil, fmt.Errorf("loading TLS keypair: %w", err) @@ -39,7 +42,6 @@ func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server MinVersion: tls.VersionTLS13, } - // mTLS: require and verify client certificates. if cfg.ClientCA != "" { caCert, err := os.ReadFile(cfg.ClientCA) if err != nil { @@ -58,6 +60,7 @@ func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server admin := &AdminServer{ srv: srv, + store: store, logger: logger, } pb.RegisterProxyAdminServer(grpcServer, admin) @@ -90,7 +93,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) ( return resp, nil } -// AddRoute adds a route to a listener's route table. +// AddRoute writes to the database first, then updates in-memory state. func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) { if req.Route == nil { return nil, status.Error(codes.InvalidArgument, "route is required") @@ -104,15 +107,23 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb. return nil, err } - if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil { + hostname := strings.ToLower(req.Route.Hostname) + + // Write-through: DB first, then memory. + if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend); err != nil { return nil, status.Errorf(codes.AlreadyExists, "%v", err) } - a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend) + if err := ls.AddRoute(hostname, req.Route.Backend); err != nil { + // DB succeeded but memory failed (should not happen since DB enforces uniqueness). + a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) + } + + a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend) return &pb.AddRouteResponse{}, nil } -// RemoveRoute removes a route from a listener's route table. +// RemoveRoute writes to the database first, then updates in-memory state. func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) { if req.Hostname == "" { return nil, status.Error(codes.InvalidArgument, "hostname is required") @@ -123,11 +134,17 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) return nil, err } - if err := ls.RemoveRoute(req.Hostname); err != nil { + hostname := strings.ToLower(req.Hostname) + + if err := a.store.DeleteRoute(ls.ID, hostname); err != nil { return nil, status.Errorf(codes.NotFound, "%v", err) } - a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname) + if err := ls.RemoveRoute(hostname); err != nil { + a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err) + } + + a.logger.Info("route removed", "listener", ls.Addr, "hostname", hostname) return &pb.RemoveRouteResponse{}, nil } @@ -158,61 +175,74 @@ func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRules return &pb.GetFirewallRulesResponse{Rules: rules}, nil } -// AddFirewallRule adds a firewall rule. +// AddFirewallRule writes to the database first, then updates in-memory state. func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) { if req.Rule == nil { return nil, status.Error(codes.InvalidArgument, "rule is required") } - fw := a.srv.Firewall() - switch req.Rule.Type { - case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: - if err := fw.AddIP(req.Rule.Value); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "%v", err) - } - case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: - if err := fw.AddCIDR(req.Rule.Value); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "%v", err) - } - case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: - if req.Rule.Value == "" { - return nil, status.Error(codes.InvalidArgument, "country code is required") - } - fw.AddCountry(req.Rule.Value) - default: - return nil, status.Error(codes.InvalidArgument, "unknown rule type") + ruleType, err := protoRuleTypeToString(req.Rule.Type) + if err != nil { + return nil, err } - a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value) + if req.Rule.Value == "" { + return nil, status.Error(codes.InvalidArgument, "value is required") + } + + // Write-through: DB first, then memory. + if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil { + return nil, status.Errorf(codes.AlreadyExists, "%v", err) + } + + fw := a.srv.Firewall() + switch ruleType { + case "ip": + if err := fw.AddIP(req.Rule.Value); err != nil { + a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) + } + case "cidr": + if err := fw.AddCIDR(req.Rule.Value); err != nil { + a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) + } + case "country": + fw.AddCountry(req.Rule.Value) + } + + a.logger.Info("firewall rule added", "type", ruleType, "value", req.Rule.Value) return &pb.AddFirewallRuleResponse{}, nil } -// RemoveFirewallRule removes a firewall rule. +// RemoveFirewallRule writes to the database first, then updates in-memory state. func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) { if req.Rule == nil { return nil, status.Error(codes.InvalidArgument, "rule is required") } - fw := a.srv.Firewall() - switch req.Rule.Type { - case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: - if err := fw.RemoveIP(req.Rule.Value); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "%v", err) - } - case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: - if err := fw.RemoveCIDR(req.Rule.Value); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "%v", err) - } - case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: - if req.Rule.Value == "" { - return nil, status.Error(codes.InvalidArgument, "country code is required") - } - fw.RemoveCountry(req.Rule.Value) - default: - return nil, status.Error(codes.InvalidArgument, "unknown rule type") + ruleType, err := protoRuleTypeToString(req.Rule.Type) + if err != nil { + return nil, err } - a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value) + if err := a.store.DeleteFirewallRule(ruleType, req.Rule.Value); err != nil { + return nil, status.Errorf(codes.NotFound, "%v", err) + } + + fw := a.srv.Firewall() + switch ruleType { + case "ip": + if err := fw.RemoveIP(req.Rule.Value); err != nil { + a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err) + } + case "cidr": + if err := fw.RemoveCIDR(req.Rule.Value); err != nil { + a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err) + } + case "country": + fw.RemoveCountry(req.Rule.Value) + } + + a.logger.Info("firewall rule removed", "type", ruleType, "value", req.Rule.Value) return &pb.RemoveFirewallRuleResponse{}, nil } @@ -244,3 +274,16 @@ func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) { } return nil, status.Errorf(codes.NotFound, "listener %q not found", addr) } + +func protoRuleTypeToString(t pb.FirewallRuleType) (string, error) { + switch t { + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: + return "ip", nil + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: + return "cidr", nil + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: + return "country", nil + default: + return "", status.Error(codes.InvalidArgument, "unknown rule type") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 6ec71a8..712e148 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -19,6 +19,7 @@ import ( // ListenerState holds the mutable state for a single proxy listener. type ListenerState struct { + ID int64 // database primary key Addr string routes map[string]string // lowercase hostname → backend addr mu sync.RWMutex @@ -75,6 +76,13 @@ func (ls *ListenerState) lookupRoute(hostname string) (string, bool) { return backend, ok } +// ListenerData holds the data needed to construct a ListenerState. +type ListenerData struct { + ID int64 + Addr string + Routes map[string]string // lowercase hostname → backend +} + // Server is the mc-proxy server. It manages listeners, firewall evaluation, // SNI-based routing, and bidirectional proxying. type Server struct { @@ -82,27 +90,19 @@ type Server struct { fw *firewall.Firewall listeners []*ListenerState logger *slog.Logger - wg sync.WaitGroup // tracks active connections + wg sync.WaitGroup startedAt time.Time version string } -// New creates a Server from the given configuration. -func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, error) { - fw, err := firewall.New(cfg.Firewall) - if err != nil { - return nil, fmt.Errorf("initializing firewall: %w", err) - } - +// New creates a Server from pre-loaded data. +func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData, logger *slog.Logger, version string) *Server { var listeners []*ListenerState - for _, lcfg := range cfg.Listeners { - routes := make(map[string]string, len(lcfg.Routes)) - for _, r := range lcfg.Routes { - routes[strings.ToLower(r.Hostname)] = r.Backend - } + for _, ld := range listenerData { listeners = append(listeners, &ListenerState{ - Addr: lcfg.Addr, - routes: routes, + ID: ld.ID, + Addr: ld.Addr, + routes: ld.Routes, }) } @@ -112,7 +112,7 @@ func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, erro listeners: listeners, logger: logger, version: version, - }, nil + } } // Firewall returns the server's firewall for use by the gRPC admin API. @@ -162,23 +162,19 @@ func (s *Server) Run(ctx context.Context) error { netListeners = append(netListeners, ln) } - // Start accept loops. for i, ln := range netListeners { ln := ln ls := s.listeners[i] go s.serve(ctx, ln, ls) } - // Block until shutdown signal. <-ctx.Done() s.logger.Info("shutting down") - // Stop accepting new connections. for _, ln := range netListeners { ln.Close() } - // Wait for in-flight connections with a timeout. done := make(chan struct{}) go func() { s.wg.Wait() diff --git a/mc-proxy.toml.example b/mc-proxy.toml.example index a05fa7d..b3be66e 100644 --- a/mc-proxy.toml.example +++ b/mc-proxy.toml.example @@ -1,6 +1,13 @@ # mc-proxy configuration +# Database. Required. Listeners, routes, and firewall rules are persisted here. +# On first run, the database is seeded from the config below. +# On subsequent runs, the database is the source of truth. +[database] +path = "/srv/mc-proxy/mc-proxy.db" + # Listeners. Each listener binds a TCP port and has its own route table. +# These are used to seed the database on first run only. [[listeners]] addr = ":443"