Extend the config, database schema, and server internals to support per-route L4/L7 mode selection and PROXY protocol fields. This is the foundation for L7 HTTP/2 reverse proxying and the multi-hop PROXY protocol support described in the updated ARCHITECTURE.md. Config: Listener gains ProxyProtocol; Route gains Mode, TLSCert, TLSKey, BackendTLS, and SendProxyProtocol. L7 routes are validated at load time (the cert/key pair must exist and parse). Mode defaults to "l4". DB: Migration v2 adds the new columns to the listeners and routes tables. CRUD and seeding are updated to persist all new fields. Server: RouteInfo replaces the bare backend string in route lookup. handleConn dispatches on route.Mode (the L7 path is a stub that logs an error and closes the connection). ListenerState and ListenerData carry the ProxyProtocol flag. All existing L4 tests pass unchanged. New tests cover migration v2, L7 field persistence, config validation for mode/cert/key, and proxy_protocol flag round-tripping. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
319 lines
7.8 KiB
Go
319 lines
7.8 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/netip"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
|
)
|
|
|
|
// RouteInfo is the complete per-hostname routing configuration: where the
// traffic goes and which proxying mode and TLS / PROXY-protocol options apply.
//
// handleConn treats any Mode other than "l7" as L4 passthrough. The TLS and
// PROXY-protocol fields are carried here but not yet consulted on the L4
// path in this file.
type RouteInfo struct {
	Backend           string // address dialed over TCP for this route
	Mode              string // "l4" or "l7"
	TLSCert, TLSKey   string
	BackendTLS        bool
	SendProxyProtocol bool
}
|
|
|
|
// ListenerState holds the mutable state for a single proxy listener.
|
|
type ListenerState struct {
|
|
ID int64 // database primary key
|
|
Addr string
|
|
ProxyProtocol bool
|
|
routes map[string]RouteInfo // lowercase hostname → route info
|
|
mu sync.RWMutex
|
|
ActiveConnections atomic.Int64
|
|
activeConns map[net.Conn]struct{} // tracked for forced shutdown
|
|
connMu sync.Mutex
|
|
}
|
|
|
|
// Routes returns a snapshot of the listener's route table.
|
|
func (ls *ListenerState) Routes() map[string]RouteInfo {
|
|
ls.mu.RLock()
|
|
defer ls.mu.RUnlock()
|
|
|
|
m := make(map[string]RouteInfo, len(ls.routes))
|
|
for k, v := range ls.routes {
|
|
m[k] = v
|
|
}
|
|
return m
|
|
}
|
|
|
|
// AddRoute adds a route to the listener. Returns an error if the hostname
|
|
// already exists.
|
|
func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) error {
|
|
key := strings.ToLower(hostname)
|
|
|
|
ls.mu.Lock()
|
|
defer ls.mu.Unlock()
|
|
|
|
if _, ok := ls.routes[key]; ok {
|
|
return fmt.Errorf("route %q already exists", hostname)
|
|
}
|
|
ls.routes[key] = info
|
|
return nil
|
|
}
|
|
|
|
// RemoveRoute removes a route from the listener. Returns an error if the
|
|
// hostname does not exist.
|
|
func (ls *ListenerState) RemoveRoute(hostname string) error {
|
|
key := strings.ToLower(hostname)
|
|
|
|
ls.mu.Lock()
|
|
defer ls.mu.Unlock()
|
|
|
|
if _, ok := ls.routes[key]; !ok {
|
|
return fmt.Errorf("route %q not found", hostname)
|
|
}
|
|
delete(ls.routes, key)
|
|
return nil
|
|
}
|
|
|
|
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
|
|
ls.mu.RLock()
|
|
defer ls.mu.RUnlock()
|
|
|
|
info, ok := ls.routes[hostname]
|
|
return info, ok
|
|
}
|
|
|
|
// ListenerData holds the data needed to construct a ListenerState.
|
|
type ListenerData struct {
|
|
ID int64
|
|
Addr string
|
|
ProxyProtocol bool
|
|
Routes map[string]RouteInfo // lowercase hostname → route info
|
|
}
|
|
|
|
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
|
|
// SNI-based routing, and bidirectional proxying.
|
|
type Server struct {
|
|
cfg *config.Config
|
|
fw *firewall.Firewall
|
|
listeners []*ListenerState
|
|
logger *slog.Logger
|
|
wg sync.WaitGroup
|
|
startedAt time.Time
|
|
version string
|
|
}
|
|
|
|
// New creates a Server from pre-loaded data.
|
|
func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData, logger *slog.Logger, version string) *Server {
|
|
var listeners []*ListenerState
|
|
for _, ld := range listenerData {
|
|
listeners = append(listeners, &ListenerState{
|
|
ID: ld.ID,
|
|
Addr: ld.Addr,
|
|
ProxyProtocol: ld.ProxyProtocol,
|
|
routes: ld.Routes,
|
|
activeConns: make(map[net.Conn]struct{}),
|
|
})
|
|
}
|
|
|
|
return &Server{
|
|
cfg: cfg,
|
|
fw: fw,
|
|
listeners: listeners,
|
|
logger: logger,
|
|
version: version,
|
|
}
|
|
}
|
|
|
|
// Firewall returns the server's firewall for use by the gRPC admin API.
|
|
func (s *Server) Firewall() *firewall.Firewall {
|
|
return s.fw
|
|
}
|
|
|
|
// Listeners returns the server's listener states for use by the gRPC admin API.
|
|
func (s *Server) Listeners() []*ListenerState {
|
|
return s.listeners
|
|
}
|
|
|
|
// StartedAt returns the time the server started.
|
|
func (s *Server) StartedAt() time.Time {
|
|
return s.startedAt
|
|
}
|
|
|
|
// Version returns the server's version string.
|
|
func (s *Server) Version() string {
|
|
return s.version
|
|
}
|
|
|
|
// TotalConnections returns the total number of active connections.
|
|
func (s *Server) TotalConnections() int64 {
|
|
var total int64
|
|
for _, ls := range s.listeners {
|
|
total += ls.ActiveConnections.Load()
|
|
}
|
|
return total
|
|
}
|
|
|
|
// Run starts all listeners and blocks until ctx is cancelled.
|
|
func (s *Server) Run(ctx context.Context) error {
|
|
s.startedAt = time.Now()
|
|
|
|
var netListeners []net.Listener
|
|
|
|
for _, ls := range s.listeners {
|
|
ln, err := net.Listen("tcp", ls.Addr)
|
|
if err != nil {
|
|
for _, l := range netListeners {
|
|
l.Close()
|
|
}
|
|
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
|
|
}
|
|
s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes))
|
|
netListeners = append(netListeners, ln)
|
|
}
|
|
|
|
for i, ln := range netListeners {
|
|
ln := ln
|
|
ls := s.listeners[i]
|
|
go s.serve(ctx, ln, ls)
|
|
}
|
|
|
|
<-ctx.Done()
|
|
s.logger.Info("shutting down")
|
|
|
|
for _, ln := range netListeners {
|
|
ln.Close()
|
|
}
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
s.wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
s.logger.Info("all connections drained")
|
|
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
|
|
s.logger.Warn("shutdown timeout exceeded, forcing close")
|
|
// Force-close all listener connections to unblock relay goroutines.
|
|
s.forceCloseAll()
|
|
<-done
|
|
}
|
|
|
|
s.fw.Close()
|
|
return nil
|
|
}
|
|
|
|
// ReloadGeoIP reloads the GeoIP database from disk.
|
|
func (s *Server) ReloadGeoIP() error {
|
|
return s.fw.ReloadGeoIP()
|
|
}
|
|
|
|
func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) {
|
|
for {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
s.logger.Error("accept error", "error", err)
|
|
continue
|
|
}
|
|
|
|
s.wg.Add(1)
|
|
ls.ActiveConnections.Add(1)
|
|
go s.handleConn(ctx, conn, ls)
|
|
}
|
|
}
|
|
|
|
// forceCloseAll closes all tracked connections across all listeners.
|
|
func (s *Server) forceCloseAll() {
|
|
for _, ls := range s.listeners {
|
|
ls.connMu.Lock()
|
|
for conn := range ls.activeConns {
|
|
conn.Close()
|
|
}
|
|
ls.connMu.Unlock()
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
|
|
defer s.wg.Done()
|
|
defer ls.ActiveConnections.Add(-1)
|
|
defer conn.Close()
|
|
|
|
ls.connMu.Lock()
|
|
ls.activeConns[conn] = struct{}{}
|
|
ls.connMu.Unlock()
|
|
defer func() {
|
|
ls.connMu.Lock()
|
|
delete(ls.activeConns, conn)
|
|
ls.connMu.Unlock()
|
|
}()
|
|
|
|
remoteAddr := conn.RemoteAddr().String()
|
|
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
|
if err != nil {
|
|
s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err)
|
|
return
|
|
}
|
|
addr := addrPort.Addr()
|
|
|
|
if s.fw.Blocked(addr) {
|
|
s.logger.Debug("blocked by firewall", "addr", addr)
|
|
return
|
|
}
|
|
|
|
hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second))
|
|
if err != nil {
|
|
s.logger.Debug("SNI extraction failed", "addr", addr, "error", err)
|
|
return
|
|
}
|
|
|
|
route, ok := ls.lookupRoute(hostname)
|
|
if !ok {
|
|
s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
|
|
return
|
|
}
|
|
|
|
// Dispatch based on route mode. L7 will be implemented in a later phase.
|
|
switch route.Mode {
|
|
case "l7":
|
|
s.logger.Error("L7 mode not yet implemented", "hostname", hostname)
|
|
return
|
|
default:
|
|
s.handleL4(ctx, conn, ls, addr, hostname, route, peeked)
|
|
}
|
|
}
|
|
|
|
// handleL4 handles an L4 (passthrough) connection.
|
|
func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState, addr netip.Addr, hostname string, route RouteInfo, peeked []byte) {
|
|
backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration)
|
|
if err != nil {
|
|
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
|
|
return
|
|
}
|
|
defer backendConn.Close()
|
|
|
|
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
|
|
|
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
|
|
if err != nil && ctx.Err() == nil {
|
|
s.logger.Debug("relay ended", "hostname", hostname, "error", err)
|
|
}
|
|
|
|
s.logger.Info("connection closed",
|
|
"addr", addr,
|
|
"hostname", hostname,
|
|
"client_bytes", result.ClientBytes,
|
|
"backend_bytes", result.BackendBytes,
|
|
)
|
|
}
|