package server

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"git.wntrmute.dev/kyle/mc-proxy/internal/config"
	"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
	"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
	"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
)

// ListenerState holds the mutable state for a single proxy listener.
//
// The route table is guarded by mu; access it through Routes, AddRoute,
// RemoveRoute, and lookupRoute rather than reading routes directly. The zero
// value is not usable as-is: routes must be a non-nil map (New in this package
// builds one), otherwise AddRoute panics on the nil-map write.
type ListenerState struct {
	Addr              string            // TCP listen address, passed to net.Listen by Run
	routes            map[string]string // lowercase hostname → backend addr
	mu                sync.RWMutex      // guards routes
	ActiveConnections atomic.Int64      // incremented on accept, decremented when the handler exits
}

// Routes returns a snapshot of the listener's route table.
//
// The returned map is a fresh copy keyed by lowercase hostname; the caller
// may modify it freely without affecting the listener.
func (ls *ListenerState) Routes() map[string]string {
	ls.mu.RLock()
	defer ls.mu.RUnlock()

	m := make(map[string]string, len(ls.routes))
	for k, v := range ls.routes {
		m[k] = v
	}
	return m
}

// AddRoute adds a route to the listener. Returns an error if the hostname
// already exists.
//
// The hostname is lower-cased before storage, so duplicates are detected
// case-insensitively; the error message quotes the hostname as supplied.
func (ls *ListenerState) AddRoute(hostname, backend string) error {
	key := strings.ToLower(hostname)

	ls.mu.Lock()
	defer ls.mu.Unlock()

	if _, ok := ls.routes[key]; ok {
		return fmt.Errorf("route %q already exists", hostname)
	}
	ls.routes[key] = backend
	return nil
}

// RemoveRoute removes a route from the listener. Returns an error if the
// hostname does not exist.
//
// Matching is case-insensitive because the hostname is lower-cased before
// the lookup, mirroring AddRoute.
func (ls *ListenerState) RemoveRoute(hostname string) error {
	key := strings.ToLower(hostname)

	ls.mu.Lock()
	defer ls.mu.Unlock()

	if _, ok := ls.routes[key]; !ok {
		return fmt.Errorf("route %q not found", hostname)
	}
	delete(ls.routes, key)
	return nil
}

// lookupRoute returns the backend address for hostname and whether a route
// exists. Unlike AddRoute and RemoveRoute it does NOT fold case: callers must
// pass an already-lowercased hostname to match the stored keys.
func (ls *ListenerState) lookupRoute(hostname string) (string, bool) {
	ls.mu.RLock()
	defer ls.mu.RUnlock()

	backend, ok := ls.routes[hostname]
	return backend, ok
}

// Server is the mc-proxy server. It manages listeners, firewall evaluation,
// SNI-based routing, and bidirectional proxying.
type Server struct { cfg *config.Config fw *firewall.Firewall listeners []*ListenerState logger *slog.Logger wg sync.WaitGroup // tracks active connections startedAt time.Time version string } // New creates a Server from the given configuration. func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, error) { fw, err := firewall.New(cfg.Firewall) if err != nil { return nil, fmt.Errorf("initializing firewall: %w", err) } var listeners []*ListenerState for _, lcfg := range cfg.Listeners { routes := make(map[string]string, len(lcfg.Routes)) for _, r := range lcfg.Routes { routes[strings.ToLower(r.Hostname)] = r.Backend } listeners = append(listeners, &ListenerState{ Addr: lcfg.Addr, routes: routes, }) } return &Server{ cfg: cfg, fw: fw, listeners: listeners, logger: logger, version: version, }, nil } // Firewall returns the server's firewall for use by the gRPC admin API. func (s *Server) Firewall() *firewall.Firewall { return s.fw } // Listeners returns the server's listener states for use by the gRPC admin API. func (s *Server) Listeners() []*ListenerState { return s.listeners } // StartedAt returns the time the server started. func (s *Server) StartedAt() time.Time { return s.startedAt } // Version returns the server's version string. func (s *Server) Version() string { return s.version } // TotalConnections returns the total number of active connections. func (s *Server) TotalConnections() int64 { var total int64 for _, ls := range s.listeners { total += ls.ActiveConnections.Load() } return total } // Run starts all listeners and blocks until ctx is cancelled. 
func (s *Server) Run(ctx context.Context) error { s.startedAt = time.Now() var netListeners []net.Listener for _, ls := range s.listeners { ln, err := net.Listen("tcp", ls.Addr) if err != nil { for _, l := range netListeners { l.Close() } return fmt.Errorf("listening on %s: %w", ls.Addr, err) } s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes)) netListeners = append(netListeners, ln) } // Start accept loops. for i, ln := range netListeners { ln := ln ls := s.listeners[i] go s.serve(ctx, ln, ls) } // Block until shutdown signal. <-ctx.Done() s.logger.Info("shutting down") // Stop accepting new connections. for _, ln := range netListeners { ln.Close() } // Wait for in-flight connections with a timeout. done := make(chan struct{}) go func() { s.wg.Wait() close(done) }() select { case <-done: s.logger.Info("all connections drained") case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration): s.logger.Warn("shutdown timeout exceeded, forcing close") } s.fw.Close() return nil } // ReloadGeoIP reloads the GeoIP database from disk. 
func (s *Server) ReloadGeoIP() error { return s.fw.ReloadGeoIP() } func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) { for { conn, err := ln.Accept() if err != nil { if ctx.Err() != nil { return } s.logger.Error("accept error", "error", err) continue } s.wg.Add(1) ls.ActiveConnections.Add(1) go s.handleConn(ctx, conn, ls) } } func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) { defer s.wg.Done() defer ls.ActiveConnections.Add(-1) defer conn.Close() remoteAddr := conn.RemoteAddr().String() addrPort, err := netip.ParseAddrPort(remoteAddr) if err != nil { s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err) return } addr := addrPort.Addr() if s.fw.Blocked(addr) { s.logger.Debug("blocked by firewall", "addr", addr) return } hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second)) if err != nil { s.logger.Debug("SNI extraction failed", "addr", addr, "error", err) return } backend, ok := ls.lookupRoute(hostname) if !ok { s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname) return } backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration) if err != nil { s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err) return } defer backendConn.Close() s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend) result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration) if err != nil && ctx.Err() == nil { s.logger.Debug("relay ended", "hostname", hostname, "error", err) } s.logger.Info("connection closed", "addr", addr, "hostname", hostname, "client_bytes", result.ClientBytes, "backend_bytes", result.BackendBytes, ) }