Files
mc-proxy/cmd/mc-proxy/server.go
Kyle Isom feeadc582b Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:05:59 -07:00

215 lines
5.6 KiB
Go

package main
import (
"context"
"fmt"
"log/slog"
"os"
"os/signal"
"strings"
"syscall"
"github.com/spf13/cobra"
"git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/mc/mc-proxy/internal/db"
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/mc/mc-proxy/internal/grpcserver"
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
"git.wntrmute.dev/mc/mc-proxy/internal/server"
)
// serverCmd builds the "server" subcommand. When run it:
//
//  1. loads the configuration file named by --config (default "mc-proxy.toml"),
//  2. opens and migrates the database, seeding it from the config file's
//     listeners and firewall settings if the database is empty,
//  3. builds the listeners, routes and firewall from the database,
//  4. optionally starts the gRPC admin API (when cfg.GRPC.Addr is set) and
//     the Prometheus metrics server (when cfg.Metrics.Addr is set), and
//  5. runs the proxy until SIGINT or SIGTERM is received.
//
// While running, SIGHUP asks the server to reload its GeoIP database.
func serverCmd() *cobra.Command {
	var configPath string

	cmd := &cobra.Command{
		Use:   "server",
		Short: "Start the proxy server",
		RunE: func(cmd *cobra.Command, args []string) error {
			cfg, err := config.Load(configPath)
			if err != nil {
				return err
			}

			logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
				Level: parseLogLevel(cfg.Log.Level),
			}))

			// Open and migrate the database.
			store, err := db.Open(cfg.Database.Path)
			if err != nil {
				return fmt.Errorf("opening database: %w", err)
			}
			defer store.Close()

			if err := store.Migrate(); err != nil {
				return fmt.Errorf("running migrations: %w", err)
			}

			// The config file's listeners and firewall settings are only
			// written to the database when it is empty; afterwards the
			// proxy is always built from what the database contains.
			empty, err := store.IsEmpty()
			if err != nil {
				return fmt.Errorf("checking database: %w", err)
			}
			if empty {
				if len(cfg.Listeners) == 0 {
					return fmt.Errorf("database is empty and no listeners defined in config for seeding")
				}
				logger.Info("seeding database from config")
				if err := store.Seed(cfg.Listeners, cfg.Firewall); err != nil {
					return fmt.Errorf("seeding database: %w", err)
				}
			}

			// Load listeners and routes from DB.
			listenerData, err := loadListenersFromDB(store)
			if err != nil {
				return err
			}

			// Load firewall rules from DB. The GeoIP path and rate-limit
			// settings still come from the config file.
			fw, err := loadFirewallFromDB(store, cfg.Firewall)
			if err != nil {
				return err
			}

			srv := server.New(cfg, fw, listenerData, logger, version)

			ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
			defer cancel()

			// SIGHUP reloads the GeoIP database. The reload goroutine exits
			// once ctx is done (shutdown signal, or the deferred cancel when
			// RunE returns). This is set up before the gRPC server so that the
			// deferred signal.Stop runs only after the gRPC server has stopped.
			sighup := make(chan os.Signal, 1)
			signal.Notify(sighup, syscall.SIGHUP)
			defer signal.Stop(sighup)
			go func() {
				for {
					select {
					case <-ctx.Done():
						return
					case <-sighup:
						logger.Info("received SIGHUP, reloading GeoIP database")
						if err := srv.ReloadGeoIP(); err != nil {
							logger.Error("failed to reload GeoIP database", "error", err)
						}
					}
				}
			}()

			// Start gRPC admin API if configured.
			if cfg.GRPC.Addr != "" {
				grpcSrv, ln, err := grpcserver.New(cfg.GRPC, srv, store, logger)
				if err != nil {
					return fmt.Errorf("starting gRPC admin API: %w", err)
				}
				logger.Info("gRPC admin API listening", "addr", ln.Addr())
				go func() {
					if err := grpcSrv.Serve(ln); err != nil {
						logger.Error("gRPC server error", "error", err)
					}
				}()
				defer func() {
					grpcSrv.GracefulStop()
					// Best-effort cleanup of the socket file. The file may
					// already be gone (for example, removed when the listener
					// closed), so "not exist" is expected and not reported.
					sock := cfg.GRPC.SocketPath()
					if err := os.Remove(sock); err != nil && !os.IsNotExist(err) {
						logger.Warn("removing gRPC socket", "path", sock, "error", err)
					}
				}()
			}

			// Start Prometheus metrics server if configured. The server is
			// given ctx; any error it returns is logged, not propagated.
			if cfg.Metrics.Addr != "" {
				logger.Info("metrics server listening", "addr", cfg.Metrics.Addr, "path", cfg.Metrics.Path)
				go func() {
					if err := metrics.ListenAndServe(ctx, cfg.Metrics.Addr, cfg.Metrics.Path); err != nil {
						logger.Error("metrics server error", "error", err)
					}
				}()
			}

			logger.Info("mc-proxy starting", "version", version)
			return srv.Run(ctx)
		},
	}

	cmd.Flags().StringVarP(&configPath, "config", "c", "mc-proxy.toml", "path to configuration file")
	return cmd
}
// loadListenersFromDB reads every stored listener together with its routes
// and turns them into the server's runtime representation.
//
// Route hostnames are lower-cased before being used as map keys. L7 policy
// rules are fetched only for routes whose Mode is "l7"; other routes get a
// nil L7Policies slice. Listeners keep the order returned by
// store.ListListeners, and the result is nil when there are none. The first
// store error aborts the load and is returned wrapped with the listener
// address or route hostname it concerns.
func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
	listeners, err := store.ListListeners()
	if err != nil {
		return nil, fmt.Errorf("loading listeners: %w", err)
	}

	var out []server.ListenerData
	for _, lis := range listeners {
		rows, err := store.ListRoutes(lis.ID)
		if err != nil {
			return nil, fmt.Errorf("loading routes for listener %q: %w", lis.Addr, err)
		}

		byHost := make(map[string]server.RouteInfo, len(rows))
		for _, rt := range rows {
			info := server.RouteInfo{
				Backend:           rt.Backend,
				Mode:              rt.Mode,
				TLSCert:           rt.TLSCert,
				TLSKey:            rt.TLSKey,
				BackendTLS:        rt.BackendTLS,
				SendProxyProtocol: rt.SendProxyProtocol,
			}

			// Only L7 routes carry policy rules.
			if rt.Mode == "l7" {
				pols, err := store.ListL7Policies(rt.ID)
				if err != nil {
					return nil, fmt.Errorf("loading L7 policies for route %q: %w", rt.Hostname, err)
				}
				for _, pol := range pols {
					info.L7Policies = append(info.L7Policies, server.L7PolicyRule{Type: pol.Type, Value: pol.Value})
				}
			}

			byHost[strings.ToLower(rt.Hostname)] = info
		}

		out = append(out, server.ListenerData{
			ID:             lis.ID,
			Addr:           lis.Addr,
			ProxyProtocol:  lis.ProxyProtocol,
			MaxConnections: lis.MaxConnections,
			Routes:         byHost,
		})
	}
	return out, nil
}
// loadFirewallFromDB builds the firewall from the rules stored in the
// database. Rule values are grouped by rule type ("ip", "cidr", "country")
// and passed to firewall.New. The GeoIP database path, rate limit and rate
// window are taken from fwCfg rather than from the database.
func loadFirewallFromDB(store *db.Store, fwCfg config.Firewall) (*firewall.Firewall, error) {
	stored, err := store.ListFirewallRules()
	if err != nil {
		return nil, fmt.Errorf("loading firewall rules: %w", err)
	}

	var ipRules, cidrRules, countryRules []string
	byType := map[string]*[]string{
		"ip":      &ipRules,
		"cidr":    &cidrRules,
		"country": &countryRules,
	}
	for _, rule := range stored {
		dst, known := byType[string(rule.Type)]
		if !known {
			// NOTE(review): rules with an unrecognized type are dropped
			// silently here; confirm types are validated when rules are stored.
			continue
		}
		*dst = append(*dst, rule.Value)
	}

	fw, err := firewall.New(fwCfg.GeoIPDB, ipRules, cidrRules, countryRules, fwCfg.RateLimit, fwCfg.RateWindow.Duration)
	if err != nil {
		return nil, fmt.Errorf("initializing firewall: %w", err)
	}
	return fw, nil
}
// parseLogLevel maps a configured log level name to an slog.Level.
//
// Matching ignores letter case and surrounding whitespace, so "DEBUG" and
// " warn " are recognised. "warning" is accepted as an alias for "warn".
// Any other value, including the empty string, falls back to
// slog.LevelInfo.
func parseLogLevel(s string) slog.Level {
	switch strings.ToLower(strings.TrimSpace(s)) {
	case "debug":
		return slog.LevelDebug
	case "warn", "warning":
		return slog.LevelWarn
	case "error":
		return slog.LevelError
	default:
		return slog.LevelInfo
	}
}