Files
mc-proxy/internal/server/server.go
Kyle Isom 28321e22f4 Make AddRoute idempotent (upsert instead of reject duplicates)
AddRoute now updates an existing route if one already exists for the
same (listener, hostname) pair, instead of returning AlreadyExists.
This makes repeated deploys idempotent — the MCP agent can register
routes on every deploy without needing to remove them first.

- DB: INSERT ... ON CONFLICT DO UPDATE (SQLite upsert)
- In-memory: overwrite existing route unconditionally
- gRPC: error code changed from AlreadyExists to Internal (for real DB errors)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 14:01:45 -07:00

450 lines
12 KiB
Go

package server
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/mc/mc-proxy/internal/l7"
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
"git.wntrmute.dev/mc/mc-proxy/internal/proxy"
"git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
"git.wntrmute.dev/mc/mc-proxy/internal/sni"
)
// L7PolicyRule is an L7 blocking policy attached to a route. Before a
// TLS-terminated connection is served, handleL7 converts each rule into an
// l7.PolicyRule with the same Type and Value; enforcement happens inside
// the l7 package.
type L7PolicyRule struct {
	// Type names the policy kind: "block_user_agent" or "require_header".
	Type string
	// Value is the policy's argument. Its interpretation depends on Type and
	// is defined by the l7 package (presumably a User-Agent pattern or a
	// header name respectively — TODO confirm in l7).
	Value string
}
// RouteInfo holds the full configuration for a single route.
type RouteInfo struct {
	// Backend is the backend address. It is dialed over TCP for L4 routes
	// and passed to the l7 package for L7 routes.
	Backend string
	// Mode selects the proxy path: "l7" terminates TLS via the l7 package.
	// Any other value (nominally "l4") is handled as L4 passthrough.
	Mode string
	// TLSCert and TLSKey are passed to the l7 package for TLS termination
	// (presumably certificate/key file paths — confirm in l7).
	TLSCert string
	TLSKey  string
	// BackendTLS is passed to the l7 package; presumably it makes the L7
	// path use TLS toward the backend — TODO confirm.
	BackendTLS bool
	// SendProxyProtocol requests a PROXY protocol v2 header toward the
	// backend carrying the client's address. The L4 path writes it directly;
	// the L7 path forwards the flag to the l7 package.
	SendProxyProtocol bool
	// L7Policies are forwarded to the l7 package for L7 routes.
	L7Policies []L7PolicyRule
}
// ListenerState holds the mutable state for a single proxy listener.
//
// Locking: mu guards routes, and SetMaxConnections writes MaxConnections
// while holding mu. connMu guards activeConns. ActiveConnections is an
// atomic counter and is used without a lock.
type ListenerState struct {
	ID             int64  // database primary key
	Addr           string // TCP address passed to net.Listen
	ProxyProtocol  bool   // when true, each client connection must start with a PROXY protocol header
	MaxConnections int64  // 0 = unlimited; written under mu by SetMaxConnections
	// routes maps lowercase hostname → route info; guarded by mu.
	routes map[string]RouteInfo
	mu     sync.RWMutex
	// ActiveConnections is incremented when a connection is accepted and
	// decremented when its handler returns.
	ActiveConnections atomic.Int64
	// activeConns tracks live client connections so they can be force-closed
	// at shutdown; guarded by connMu.
	activeConns map[net.Conn]struct{}
	connMu      sync.Mutex
}
// SetMaxConnections changes the listener's connection cap while the server
// is running. A value of 0 (or any non-positive value) means unlimited.
// The field is written while holding ls.mu, so concurrent readers should
// hold ls.mu as well.
func (ls *ListenerState) SetMaxConnections(n int64) {
	ls.mu.Lock()
	defer ls.mu.Unlock()

	limit := n
	ls.MaxConnections = limit
}
// Routes returns a copy of the listener's hostname → route table, taken
// under the read lock. Adding or deleting entries in the returned map does
// not affect the listener. The copy is shallow: each RouteInfo's
// L7Policies slice still shares storage with the live table, so callers
// must treat it as read-only.
func (ls *ListenerState) Routes() map[string]RouteInfo {
	ls.mu.RLock()
	defer ls.mu.RUnlock()

	snapshot := make(map[string]RouteInfo, len(ls.routes))
	for host, info := range ls.routes {
		snapshot[host] = info
	}
	return snapshot
}
// AddRoute adds or updates a route on the listener. If a route for the
// hostname already exists, it is replaced (upsert). The hostname is
// lowercased before it is stored.
//
// If the listener's route table has not been allocated yet (for example a
// listener built from data whose route map was nil), it is created here.
// Without this, the write would panic on a nil map.
func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) {
	key := strings.ToLower(hostname)
	ls.mu.Lock()
	defer ls.mu.Unlock()
	if ls.routes == nil {
		ls.routes = make(map[string]RouteInfo)
	}
	ls.routes[key] = info
}
// RemoveRoute deletes the route for hostname, matched after lowercasing.
// If no such route exists, it returns an error that quotes hostname exactly
// as the caller passed it.
func (ls *ListenerState) RemoveRoute(hostname string) error {
	ls.mu.Lock()
	defer ls.mu.Unlock()

	normalized := strings.ToLower(hostname)
	if _, exists := ls.routes[normalized]; exists {
		delete(ls.routes, normalized)
		return nil
	}
	return fmt.Errorf("route %q not found", hostname)
}
// AddL7Policy appends policy to the L7 policy list of the route for
// hostname (matched after lowercasing). If the listener has no route for
// that hostname, it does nothing.
//
// The list is rebuilt (copy-on-write) instead of appended to in place.
// RouteInfo values returned by Routes and lookupRoute are shallow copies:
// their L7Policies share the live slice's backing array, and they may be
// read after ls.mu is released. A plain append can write into spare
// capacity of that shared array. Building a new slice keeps every slice
// that was already handed out unchanged.
func (ls *ListenerState) AddL7Policy(hostname string, policy L7PolicyRule) {
	key := strings.ToLower(hostname)
	ls.mu.Lock()
	defer ls.mu.Unlock()
	route, ok := ls.routes[key]
	if !ok {
		return
	}
	policies := make([]L7PolicyRule, 0, len(route.L7Policies)+1)
	policies = append(policies, route.L7Policies...)
	route.L7Policies = append(policies, policy)
	ls.routes[key] = route
}
// RemoveL7Policy deletes every L7 policy on the route for hostname (matched
// after lowercasing) whose Type equals policyType and whose Value equals
// policyValue. If the route does not exist, it does nothing.
//
// The filtered list is built in a new slice rather than in place. RouteInfo
// values returned by Routes and lookupRoute are shallow copies whose
// L7Policies share the live slice's backing array, and they are read after
// ls.mu has been released (handleL7 iterates them without the lock).
// Filtering in place (route.L7Policies[:0]) would rewrite that shared array
// underneath those readers. That is a data race, and it corrupts
// previously taken snapshots.
func (ls *ListenerState) RemoveL7Policy(hostname, policyType, policyValue string) {
	key := strings.ToLower(hostname)
	ls.mu.Lock()
	defer ls.mu.Unlock()
	route, ok := ls.routes[key]
	if !ok {
		return
	}
	kept := make([]L7PolicyRule, 0, len(route.L7Policies))
	for _, p := range route.L7Policies {
		if p.Type == policyType && p.Value == policyValue {
			continue
		}
		kept = append(kept, p)
	}
	route.L7Policies = kept
	ls.routes[key] = route
}
// lookupRoute returns the route for hostname, if any.
//
// The hostname is lowercased before the lookup so it matches the
// normalized keys that AddRoute stores. DNS names (and therefore TLS SNI
// host names) are case-insensitive, and a client may send mixed case.
// strings.ToLower returns already-lowercase ASCII input without
// allocating, so the common case costs nothing extra.
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
	key := strings.ToLower(hostname)
	ls.mu.RLock()
	defer ls.mu.RUnlock()
	info, ok := ls.routes[key]
	return info, ok
}
// ListenerData holds the data needed to construct a ListenerState. New
// builds one ListenerState from each element, in order.
type ListenerData struct {
	ID             int64                // database primary key, copied to ListenerState.ID
	Addr           string               // TCP address to listen on
	ProxyProtocol  bool                 // expect a PROXY protocol header from each client
	MaxConnections int64                // per-listener connection cap; 0 = unlimited
	Routes         map[string]RouteInfo // lowercase hostname → route info
}
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
// SNI-based routing, and bidirectional proxying.
type Server struct {
	cfg       *config.Config     // proxy timeouts (connect, idle, shutdown) are read from cfg.Proxy
	fw        *firewall.Firewall // consulted for every client address; closed at the end of Run
	listeners []*ListenerState   // one per ListenerData passed to New, in the same order
	logger    *slog.Logger
	wg        sync.WaitGroup // goroutines Run waits for during shutdown (e.g. per-connection handlers)
	startedAt time.Time      // set when Run starts; zero value before that
	version   string         // reported by Version
}
// New creates a Server from pre-loaded data. Listeners are not opened until
// Run is called.
//
// Each listener gets its own copy of the route map, with hostnames
// lowercased to match the keys AddRoute writes. This has two effects:
// mutations made later (for example through the admin API) never alias the
// caller's map, and a listener loaded with a nil route map still accepts
// AddRoute calls instead of panicking on a nil-map write. If the input
// contains keys that differ only in case, one of them wins arbitrarily.
func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData, logger *slog.Logger, version string) *Server {
	listeners := make([]*ListenerState, 0, len(listenerData))
	for _, ld := range listenerData {
		routes := make(map[string]RouteInfo, len(ld.Routes))
		for host, info := range ld.Routes {
			routes[strings.ToLower(host)] = info
		}
		listeners = append(listeners, &ListenerState{
			ID:             ld.ID,
			Addr:           ld.Addr,
			ProxyProtocol:  ld.ProxyProtocol,
			MaxConnections: ld.MaxConnections,
			routes:         routes,
			activeConns:    make(map[net.Conn]struct{}),
		})
	}
	return &Server{
		cfg:       cfg,
		fw:        fw,
		listeners: listeners,
		logger:    logger,
		version:   version,
	}
}
// Firewall exposes the firewall the server evaluates for every accepted
// connection, for use by the gRPC admin API.
func (s *Server) Firewall() (fw *firewall.Firewall) {
	fw = s.fw
	return fw
}
// Listeners exposes the per-listener state, in the order the listener data
// was given to New, for use by the gRPC admin API. The returned slice is
// the server's own (not a copy) and must not be modified.
func (s *Server) Listeners() (states []*ListenerState) {
	states = s.listeners
	return states
}
// StartedAt reports when Run began. It is the zero time.Time until Run has
// been called.
func (s *Server) StartedAt() (t time.Time) {
	t = s.startedAt
	return t
}
// Version reports the version string that was supplied to New.
func (s *Server) Version() (v string) {
	v = s.version
	return v
}
// listenerAddrForRoute reports the address of the first listener (in the
// order given to New) whose route table holds hostname, after lowercasing
// it. It returns "unknown" when no listener has a matching route.
func (s *Server) listenerAddrForRoute(hostname string) string {
	normalized := strings.ToLower(hostname)
	for i := range s.listeners {
		if _, found := s.listeners[i].lookupRoute(normalized); found {
			return s.listeners[i].Addr
		}
	}
	return "unknown"
}
// TotalConnections sums the live connection counters of every listener.
// Each counter is loaded atomically, but the sum as a whole is not an
// atomic snapshot across listeners.
func (s *Server) TotalConnections() int64 {
	sum := int64(0)
	for i := range s.listeners {
		sum += s.listeners[i].ActiveConnections.Load()
	}
	return sum
}
// Run starts all listeners and blocks until ctx is cancelled, then shuts
// down.
//
// All listeners are bound before any accept loop starts. If one fails to
// bind, the ones already opened are closed and the error is returned.
// After ctx is done, the listeners are closed and Run waits for the accept
// loops and in-flight connections to finish. If that takes longer than
// cfg.Proxy.ShutdownTimeout, every tracked client connection is
// force-closed, and Run then waits for the handlers to return. The firewall
// is closed before Run returns nil.
func (s *Server) Run(ctx context.Context) error {
	s.startedAt = time.Now()

	netListeners := make([]net.Listener, 0, len(s.listeners))
	for _, ls := range s.listeners {
		ln, err := net.Listen("tcp", ls.Addr)
		if err != nil {
			for _, l := range netListeners {
				_ = l.Close()
			}
			return fmt.Errorf("listening on %s: %w", ls.Addr, err)
		}
		// The route table can be mutated concurrently (AddRoute and friends
		// take ls.mu), so read its size under the read lock.
		ls.mu.RLock()
		routeCount := len(ls.routes)
		ls.mu.RUnlock()
		s.logger.Info("listening", "addr", ln.Addr(), "routes", routeCount)
		netListeners = append(netListeners, ln)
	}

	// The accept loops are tracked in s.wg as well. serve calls s.wg.Add(1)
	// for every accepted connection. sync.WaitGroup requires that an Add
	// made while the counter is zero happens before Wait. Keeping each loop
	// registered guarantees the counter is non-zero whenever serve may still
	// call Add, so those Adds never race with the Wait below.
	for i, ln := range netListeners {
		s.wg.Add(1)
		go func(ln net.Listener, ls *ListenerState) {
			defer s.wg.Done()
			s.serve(ctx, ln, ls)
		}(ln, s.listeners[i])
	}

	<-ctx.Done()
	s.logger.Info("shutting down")
	for _, ln := range netListeners {
		_ = ln.Close()
	}

	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		s.logger.Info("all connections drained")
	case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
		s.logger.Warn("shutdown timeout exceeded, forcing close")
		// Force-close all listener connections to unblock relay goroutines.
		s.forceCloseAll()
		<-done
	}

	_ = s.fw.Close()
	return nil
}
// ReloadGeoIP asks the firewall to reload its GeoIP database from disk and
// passes back whatever error the firewall reports.
func (s *Server) ReloadGeoIP() error {
	if err := s.fw.ReloadGeoIP(); err != nil {
		return err
	}
	return nil
}
// serve runs the accept loop for a single listener until ctx is cancelled.
// Run closes ln only after ctx is done, so the Accept error that follows
// the close is recognized as shutdown and the loop returns.
//
// For each accepted connection, serve does the following:
//   - enforces the listener's connection limit, closing the connection
//     immediately if the limit is reached;
//   - registers the connection with s.wg, ls.ActiveConnections and the
//     active-connections gauge;
//   - hands it to handleConn in a new goroutine, and handleConn undoes
//     those registrations when it returns.
//
// MaxConnections is read under ls.mu because SetMaxConnections writes it
// under that lock while the server runs. An unlocked read would be a data
// race.
//
// Accept errors not caused by shutdown are logged and retried after an
// exponential backoff (5ms, doubling, capped at 1s). This keeps a
// persistent failure, such as file-descriptor exhaustion, from becoming a
// busy loop that floods the log. The backoff resets after the next
// successful Accept.
func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) {
	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration

	for {
		conn, err := ln.Accept()
		if err != nil {
			if ctx.Err() != nil {
				return
			}
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			s.logger.Error("accept error", "error", err, "retry_in", backoff)
			timer := time.NewTimer(backoff)
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
			}
			continue
		}
		backoff = 0

		// Enforce per-listener connection limit.
		ls.mu.RLock()
		limit := ls.MaxConnections
		ls.mu.RUnlock()
		if limit > 0 && ls.ActiveConnections.Load() >= limit {
			_ = conn.Close()
			s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", limit)
			continue
		}

		s.wg.Add(1)
		ls.ActiveConnections.Add(1)
		metrics.ConnectionsActive.WithLabelValues(ls.Addr).Inc()
		go s.handleConn(ctx, conn, ls)
	}
}
// forceCloseAll closes every connection currently registered in any
// listener's activeConns set, holding that listener's connMu while it does
// so. It is intended to unblock handler goroutines still doing I/O at
// shutdown. Each handler removes its own entry from the set when it exits.
func (s *Server) forceCloseAll() {
	for i := range s.listeners {
		ls := s.listeners[i]
		ls.connMu.Lock()
		for c := range ls.activeConns {
			_ = c.Close()
		}
		ls.connMu.Unlock()
	}
}
// handleConn processes one accepted client connection and runs in its own
// goroutine. serve has already incremented s.wg, ls.ActiveConnections and
// the active-connections gauge for it; the deferred calls below undo those
// increments.
//
// The connection is handled in this order:
//  1. Register conn for forced shutdown.
//  2. Determine the client address, taking it from the PROXY protocol
//     header when the listener expects one.
//  3. Apply the firewall.
//  4. Read the SNI host name.
//  5. Look up the route.
//  6. Dispatch to the L4 or L7 handler.
//
// Any failure before dispatch just closes the connection.
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
	defer s.wg.Done()
	defer ls.ActiveConnections.Add(-1)
	defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
	defer func() { _ = conn.Close() }()

	ls.connMu.Lock()
	ls.activeConns[conn] = struct{}{}
	ls.connMu.Unlock()
	defer func() {
		ls.connMu.Lock()
		delete(ls.activeConns, conn)
		ls.connMu.Unlock()
	}()

	remoteAddr := conn.RemoteAddr().String()
	addrPort, err := netip.ParseAddrPort(remoteAddr)
	if err != nil {
		s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err)
		return
	}
	addr := addrPort.Addr()

	// Parse PROXY protocol header if enabled on this listener.
	if ls.ProxyProtocol {
		hdr, err := proxyproto.Parse(conn, time.Now().Add(5*time.Second))
		if err != nil {
			s.logger.Debug("PROXY protocol parse failed", "addr", addr, "error", err)
			return
		}
		if hdr.Command == proxyproto.CommandProxy {
			addr = hdr.SrcAddr.Addr()
			addrPort = hdr.SrcAddr
			s.logger.Debug("PROXY protocol", "real_addr", addr, "peer_addr", remoteAddr)
		}
	}

	if blocked, reason := s.fw.BlockedWithReason(addr); blocked {
		metrics.FirewallBlockedTotal.WithLabelValues(reason).Inc()
		s.logger.Debug("blocked by firewall", "addr", addr, "reason", reason)
		return
	}

	sniName, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second))
	if err != nil {
		s.logger.Debug("SNI extraction failed", "addr", addr, "error", err)
		return
	}
	// SNI host names are case-insensitive, but route keys are stored
	// lowercased. Normalize once here so the route lookup, log fields and
	// per-hostname metric labels all agree regardless of the client's casing.
	hostname := strings.ToLower(sniName)

	route, ok := ls.lookupRoute(hostname)
	if !ok {
		s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
		return
	}

	// Dispatch based on route mode. The connection is counted against the
	// listener that actually accepted it. Looking the listener up by
	// hostname would pick the first listener with a matching route, which
	// is wrong when several listeners serve the same hostname.
	switch route.Mode {
	case "l7":
		metrics.ConnectionsTotal.WithLabelValues(ls.Addr, "l7").Inc()
		s.handleL7(ctx, conn, addr, addrPort, hostname, route, peeked)
	default:
		metrics.ConnectionsTotal.WithLabelValues(ls.Addr, "l4").Inc()
		s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked)
	}
}

// handleL4 handles an L4 (passthrough) connection. It does the following:
//   - dials route.Backend over TCP within the configured connect timeout;
//   - optionally writes a PROXY protocol v2 header carrying clientAddrPort;
//   - hands both connections plus the bytes consumed during SNI extraction
//     (peeked) to proxy.Relay with the configured idle timeout.
//
// The caller owns conn and closes it. handleL4 closes the backend
// connection. The connections-total metric is recorded by the caller,
// which knows which listener accepted the connection.
func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
	// Unlike net.DialTimeout, DialContext also aborts when ctx is cancelled,
	// so a slow backend cannot delay shutdown by a full connect timeout.
	dialer := net.Dialer{Timeout: s.cfg.Proxy.ConnectTimeout.Duration}
	dialStart := time.Now()
	backendConn, err := dialer.DialContext(ctx, "tcp", route.Backend)
	metrics.BackendDialDuration.WithLabelValues(route.Backend).Observe(time.Since(dialStart).Seconds())
	if err != nil {
		if ctx.Err() == nil {
			s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
		}
		return
	}
	defer func() { _ = backendConn.Close() }()

	// Send PROXY protocol v2 header to backend if configured.
	if route.SendProxyProtocol {
		// RemoteAddr of a TCP connection renders as "ip:port" (IPv6 bracketed,
		// zones allowed), which netip.ParseAddrPort accepts, so the error is
		// not expected here and is deliberately ignored.
		backendAddrPort, _ := netip.ParseAddrPort(backendConn.RemoteAddr().String())
		if err := proxyproto.WriteV2(backendConn, clientAddrPort, backendAddrPort); err != nil {
			s.logger.Error("writing PROXY protocol header", "hostname", hostname, "error", err)
			return
		}
	}

	s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
	result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
	if err != nil && ctx.Err() == nil {
		s.logger.Debug("relay ended", "hostname", hostname, "error", err)
	}
	metrics.TransferredBytesTotal.WithLabelValues("client_to_backend", hostname).Add(float64(result.ClientBytes))
	metrics.TransferredBytesTotal.WithLabelValues("backend_to_client", hostname).Add(float64(result.BackendBytes))
	s.logger.Info("connection closed",
		"addr", addr,
		"hostname", hostname,
		"client_bytes", result.ClientBytes,
		"backend_bytes", result.BackendBytes,
	)
}

// handleL7 handles an L7 (TLS-terminating) connection. It converts the route
// and its policies into an l7.RouteConfig and delegates the connection,
// together with the bytes consumed during SNI extraction (peeked), to
// l7.Serve. Errors from l7.Serve are logged at Debug level unless ctx has
// been cancelled. The connections-total metric is recorded by the caller,
// which knows which listener accepted the connection.
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
	s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)

	var policies []l7.PolicyRule
	for _, p := range route.L7Policies {
		policies = append(policies, l7.PolicyRule{Type: p.Type, Value: p.Value})
	}

	rc := l7.RouteConfig{
		Hostname:          hostname,
		Backend:           route.Backend,
		TLSCert:           route.TLSCert,
		TLSKey:            route.TLSKey,
		BackendTLS:        route.BackendTLS,
		SendProxyProtocol: route.SendProxyProtocol,
		ConnectTimeout:    s.cfg.Proxy.ConnectTimeout.Duration,
		Policies:          policies,
	}

	if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {
		if ctx.Err() == nil {
			s.logger.Debug("L7 serve ended", "hostname", hostname, "error", err)
		}
	}
	s.logger.Info("L7 connection closed", "addr", addr, "hostname", hostname)
}