package server import ( "context" "fmt" "log/slog" "net" "net/netip" "strings" "sync" "sync/atomic" "time" "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/l7" "git.wntrmute.dev/kyle/mc-proxy/internal/proxy" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/kyle/mc-proxy/internal/sni" ) // RouteInfo holds the full configuration for a single route. type RouteInfo struct { Backend string Mode string // "l4" or "l7" TLSCert string TLSKey string BackendTLS bool SendProxyProtocol bool } // ListenerState holds the mutable state for a single proxy listener. type ListenerState struct { ID int64 // database primary key Addr string ProxyProtocol bool MaxConnections int64 // 0 = unlimited routes map[string]RouteInfo // lowercase hostname → route info mu sync.RWMutex ActiveConnections atomic.Int64 activeConns map[net.Conn]struct{} // tracked for forced shutdown connMu sync.Mutex } // SetMaxConnections updates the connection limit at runtime. func (ls *ListenerState) SetMaxConnections(n int64) { ls.mu.Lock() defer ls.mu.Unlock() ls.MaxConnections = n } // Routes returns a snapshot of the listener's route table. func (ls *ListenerState) Routes() map[string]RouteInfo { ls.mu.RLock() defer ls.mu.RUnlock() m := make(map[string]RouteInfo, len(ls.routes)) for k, v := range ls.routes { m[k] = v } return m } // AddRoute adds a route to the listener. Returns an error if the hostname // already exists. func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) error { key := strings.ToLower(hostname) ls.mu.Lock() defer ls.mu.Unlock() if _, ok := ls.routes[key]; ok { return fmt.Errorf("route %q already exists", hostname) } ls.routes[key] = info return nil } // RemoveRoute removes a route from the listener. Returns an error if the // hostname does not exist. func (ls *ListenerState) RemoveRoute(hostname string) error { key := strings.ToLower(hostname) ls.mu.Lock() defer ls.mu.Unlock() if _, ok := ls.routes[key]; !ok { return fmt.Errorf("route %q not found", hostname) } delete(ls.routes, key) return nil } func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) { ls.mu.RLock() defer ls.mu.RUnlock() info, ok := ls.routes[hostname] return info, ok } // ListenerData holds the data needed to construct a ListenerState. type ListenerData struct { ID int64 Addr string ProxyProtocol bool MaxConnections int64 Routes map[string]RouteInfo // lowercase hostname → route info } // Server is the mc-proxy server. It manages listeners, firewall evaluation, // SNI-based routing, and bidirectional proxying. type Server struct { cfg *config.Config fw *firewall.Firewall listeners []*ListenerState logger *slog.Logger wg sync.WaitGroup startedAt time.Time version string } // New creates a Server from pre-loaded data. func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData, logger *slog.Logger, version string) *Server { var listeners []*ListenerState for _, ld := range listenerData { listeners = append(listeners, &ListenerState{ ID: ld.ID, Addr: ld.Addr, ProxyProtocol: ld.ProxyProtocol, MaxConnections: ld.MaxConnections, routes: ld.Routes, activeConns: make(map[net.Conn]struct{}), }) } return &Server{ cfg: cfg, fw: fw, listeners: listeners, logger: logger, version: version, } } // Firewall returns the server's firewall for use by the gRPC admin API. func (s *Server) Firewall() *firewall.Firewall { return s.fw } // Listeners returns the server's listener states for use by the gRPC admin API. func (s *Server) Listeners() []*ListenerState { return s.listeners } // StartedAt returns the time the server started. func (s *Server) StartedAt() time.Time { return s.startedAt } // Version returns the server's version string. func (s *Server) Version() string { return s.version } // TotalConnections returns the total number of active connections. func (s *Server) TotalConnections() int64 { var total int64 for _, ls := range s.listeners { total += ls.ActiveConnections.Load() } return total } // Run starts all listeners and blocks until ctx is cancelled. func (s *Server) Run(ctx context.Context) error { s.startedAt = time.Now() var netListeners []net.Listener for _, ls := range s.listeners { ln, err := net.Listen("tcp", ls.Addr) if err != nil { for _, l := range netListeners { l.Close() } return fmt.Errorf("listening on %s: %w", ls.Addr, err) } s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes)) netListeners = append(netListeners, ln) } for i, ln := range netListeners { ln := ln ls := s.listeners[i] go s.serve(ctx, ln, ls) } <-ctx.Done() s.logger.Info("shutting down") for _, ln := range netListeners { ln.Close() } done := make(chan struct{}) go func() { s.wg.Wait() close(done) }() select { case <-done: s.logger.Info("all connections drained") case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration): s.logger.Warn("shutdown timeout exceeded, forcing close") // Force-close all listener connections to unblock relay goroutines. s.forceCloseAll() <-done } s.fw.Close() return nil } // ReloadGeoIP reloads the GeoIP database from disk. func (s *Server) ReloadGeoIP() error { return s.fw.ReloadGeoIP() } func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) { for { conn, err := ln.Accept() if err != nil { if ctx.Err() != nil { return } s.logger.Error("accept error", "error", err) continue } // Enforce per-listener connection limit. if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections { conn.Close() s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections) continue } s.wg.Add(1) ls.ActiveConnections.Add(1) go s.handleConn(ctx, conn, ls) } } // forceCloseAll closes all tracked connections across all listeners. func (s *Server) forceCloseAll() { for _, ls := range s.listeners { ls.connMu.Lock() for conn := range ls.activeConns { conn.Close() } ls.connMu.Unlock() } } func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) { defer s.wg.Done() defer ls.ActiveConnections.Add(-1) defer conn.Close() ls.connMu.Lock() ls.activeConns[conn] = struct{}{} ls.connMu.Unlock() defer func() { ls.connMu.Lock() delete(ls.activeConns, conn) ls.connMu.Unlock() }() remoteAddr := conn.RemoteAddr().String() addrPort, err := netip.ParseAddrPort(remoteAddr) if err != nil { s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err) return } addr := addrPort.Addr() // Parse PROXY protocol header if enabled on this listener. if ls.ProxyProtocol { hdr, err := proxyproto.Parse(conn, time.Now().Add(5*time.Second)) if err != nil { s.logger.Debug("PROXY protocol parse failed", "addr", addr, "error", err) return } if hdr.Command == proxyproto.CommandProxy { addr = hdr.SrcAddr.Addr() addrPort = hdr.SrcAddr s.logger.Debug("PROXY protocol", "real_addr", addr, "peer_addr", remoteAddr) } } if s.fw.Blocked(addr) { s.logger.Debug("blocked by firewall", "addr", addr) return } hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second)) if err != nil { s.logger.Debug("SNI extraction failed", "addr", addr, "error", err) return } route, ok := ls.lookupRoute(hostname) if !ok { s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname) return } // Dispatch based on route mode. switch route.Mode { case "l7": s.handleL7(ctx, conn, addr, addrPort, hostname, route, peeked) default: s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked) } } // handleL4 handles an L4 (passthrough) connection. func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) { backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration) if err != nil { s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err) return } defer backendConn.Close() // Send PROXY protocol v2 header to backend if configured. if route.SendProxyProtocol { backendAddrPort, _ := netip.ParseAddrPort(backendConn.RemoteAddr().String()) if err := proxyproto.WriteV2(backendConn, clientAddrPort, backendAddrPort); err != nil { s.logger.Error("writing PROXY protocol header", "hostname", hostname, "error", err) return } } s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend) result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration) if err != nil && ctx.Err() == nil { s.logger.Debug("relay ended", "hostname", hostname, "error", err) } s.logger.Info("connection closed", "addr", addr, "hostname", hostname, "client_bytes", result.ClientBytes, "backend_bytes", result.BackendBytes, ) } // handleL7 handles an L7 (TLS-terminating) connection. func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) { s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend) rc := l7.RouteConfig{ Backend: route.Backend, TLSCert: route.TLSCert, TLSKey: route.TLSKey, BackendTLS: route.BackendTLS, SendProxyProtocol: route.SendProxyProtocol, ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration, } if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil { if ctx.Err() == nil { s.logger.Debug("L7 serve ended", "hostname", hostname, "error", err) } } s.logger.Info("L7 connection closed", "addr", addr, "hostname", hostname) }