Add L7/PROXY protocol data model, config, and architecture docs

Extend the config, database schema, and server internals to support
per-route L4/L7 mode selection and PROXY protocol fields. This is the
foundation for L7 HTTP/2 reverse proxying and multi-hop PROXY protocol
support described in the updated ARCHITECTURE.md.

Config: Listener gains ProxyProtocol; Route gains Mode, TLSCert,
TLSKey, BackendTLS, SendProxyProtocol. L7 routes validated at load
time (cert/key pair must exist and parse). Mode defaults to "l4".

DB: Migration v2 adds columns to listeners and routes tables. CRUD
and seeding updated to persist all new fields.

Server: RouteInfo replaces bare backend string in route lookup.
handleConn dispatches on route.Mode (L7 path stubbed with error).
ListenerState and ListenerData carry ProxyProtocol flag.

All existing L4 tests pass unchanged. New tests cover migration v2,
L7 field persistence, config validation for mode/cert/key, and
proxy_protocol flag round-tripping.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 13:15:51 -07:00
parent 666d55018c
commit ed94548dfa
17 changed files with 1283 additions and 205 deletions

View File

@@ -1,6 +1,7 @@
package config
import (
"crypto/tls"
"fmt"
"os"
"strings"
@@ -27,13 +28,19 @@ type GRPC struct {
}
type Listener struct {
Addr string `toml:"addr"`
Routes []Route `toml:"routes"`
Addr string `toml:"addr"`
ProxyProtocol bool `toml:"proxy_protocol"`
Routes []Route `toml:"routes"`
}
type Route struct {
Hostname string `toml:"hostname"`
Backend string `toml:"backend"`
Hostname string `toml:"hostname"`
Backend string `toml:"backend"`
Mode string `toml:"mode"` // "l4" (default) or "l7"
TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only)
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
}
type Firewall struct {
@@ -163,12 +170,14 @@ func (c *Config) validate() error {
}
// Validate listeners if provided (used for seeding on first run).
for i, l := range c.Listeners {
for i := range c.Listeners {
l := &c.Listeners[i]
if l.Addr == "" {
return fmt.Errorf("listener %d: addr is required", i)
}
seen := make(map[string]bool)
for j, r := range l.Routes {
for j := range l.Routes {
r := &l.Routes[j]
if r.Hostname == "" {
return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j)
}
@@ -179,6 +188,27 @@ func (c *Config) validate() error {
return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname)
}
seen[r.Hostname] = true
// Normalize mode: empty defaults to "l4".
if r.Mode == "" {
r.Mode = "l4"
}
if r.Mode != "l4" && r.Mode != "l7" {
return fmt.Errorf("listener %d (%s), route %d (%s): mode must be \"l4\" or \"l7\", got %q",
i, l.Addr, j, r.Hostname, r.Mode)
}
// L7 routes require TLS cert and key.
if r.Mode == "l7" {
if r.TLSCert == "" || r.TLSKey == "" {
return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key",
i, l.Addr, j, r.Hostname)
}
if _, err := tls.LoadX509KeyPair(r.TLSCert, r.TLSKey); err != nil {
return fmt.Errorf("listener %d (%s), route %d (%s): loading TLS cert/key: %w",
i, l.Addr, j, r.Hostname, err)
}
}
}
}

View File

@@ -391,3 +391,147 @@ path = "/tmp/test.db"
t.Fatalf("got grpc.addr %q, want %q", cfg.GRPC.Addr, "/var/run/override.sock")
}
}
func TestLoadL4ModeDefault(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Mode should be normalized to "l4" when unset.
if cfg.Listeners[0].Routes[0].Mode != "l4" {
t.Fatalf("got mode %q, want %q", cfg.Listeners[0].Routes[0].Mode, "l4")
}
}
func TestLoadInvalidMode(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
mode = "l5"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for invalid mode")
}
}
func TestLoadL7RequiresCertKey(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8080"
mode = "l7"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for L7 route without cert/key")
}
}
func TestLoadL7InvalidCertKey(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8080"
mode = "l7"
tls_cert = "/nonexistent/cert.pem"
tls_key = "/nonexistent/key.pem"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for L7 route with nonexistent cert/key files")
}
}
func TestLoadProxyProtocol(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[[listeners]]
addr = ":443"
proxy_protocol = true
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
send_proxy_protocol = true
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !cfg.Listeners[0].ProxyProtocol {
t.Fatal("expected proxy_protocol = true")
}
if !cfg.Listeners[0].Routes[0].SendProxyProtocol {
t.Fatal("expected send_proxy_protocol = true")
}
}

View File

@@ -41,7 +41,7 @@ func TestIsEmpty(t *testing.T) {
t.Fatal("expected empty database")
}
if _, err := store.CreateListener(":443"); err != nil {
if _, err := store.CreateListener(":443", false); err != nil {
t.Fatalf("create listener: %v", err)
}
@@ -57,7 +57,7 @@ func TestIsEmpty(t *testing.T) {
func TestListenerCRUD(t *testing.T) {
store := openTestDB(t)
id, err := store.CreateListener(":443")
id, err := store.CreateListener(":443", false)
if err != nil {
t.Fatalf("create: %v", err)
}
@@ -75,6 +75,9 @@ func TestListenerCRUD(t *testing.T) {
if listeners[0].Addr != ":443" {
t.Fatalf("got addr %q, want %q", listeners[0].Addr, ":443")
}
if listeners[0].ProxyProtocol {
t.Fatal("expected proxy_protocol = false")
}
l, err := store.GetListenerByAddr(":443")
if err != nil {
@@ -97,13 +100,33 @@ func TestListenerCRUD(t *testing.T) {
}
}
func TestListenerProxyProtocol(t *testing.T) {
store := openTestDB(t)
id, err := store.CreateListener(":443", true)
if err != nil {
t.Fatalf("create: %v", err)
}
l, err := store.GetListenerByAddr(":443")
if err != nil {
t.Fatalf("get by addr: %v", err)
}
if l.ID != id {
t.Fatalf("got ID %d, want %d", l.ID, id)
}
if !l.ProxyProtocol {
t.Fatal("expected proxy_protocol = true")
}
}
func TestListenerDuplicateAddr(t *testing.T) {
store := openTestDB(t)
if _, err := store.CreateListener(":443"); err != nil {
if _, err := store.CreateListener(":443", false); err != nil {
t.Fatalf("first create: %v", err)
}
if _, err := store.CreateListener(":443"); err == nil {
if _, err := store.CreateListener(":443", false); err == nil {
t.Fatal("expected error for duplicate addr")
}
}
@@ -111,12 +134,12 @@ func TestListenerDuplicateAddr(t *testing.T) {
func TestRouteCRUD(t *testing.T) {
store := openTestDB(t)
listenerID, err := store.CreateListener(":443")
listenerID, err := store.CreateListener(":443", false)
if err != nil {
t.Fatalf("create listener: %v", err)
}
routeID, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443")
routeID, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false)
if err != nil {
t.Fatalf("create route: %v", err)
}
@@ -134,6 +157,9 @@ func TestRouteCRUD(t *testing.T) {
if routes[0].Hostname != "example.com" {
t.Fatalf("got hostname %q, want %q", routes[0].Hostname, "example.com")
}
if routes[0].Mode != "l4" {
t.Fatalf("got mode %q, want %q", routes[0].Mode, "l4")
}
if err := store.DeleteRoute(listenerID, "example.com"); err != nil {
t.Fatalf("delete route: %v", err)
@@ -148,14 +174,51 @@ func TestRouteCRUD(t *testing.T) {
}
}
func TestRouteL7Fields(t *testing.T) {
store := openTestDB(t)
listenerID, _ := store.CreateListener(":443", false)
_, err := store.CreateRoute(listenerID, "api.example.com", "127.0.0.1:8080", "l7",
"/certs/api.crt", "/certs/api.key", false, true)
if err != nil {
t.Fatalf("create L7 route: %v", err)
}
routes, err := store.ListRoutes(listenerID)
if err != nil {
t.Fatalf("list routes: %v", err)
}
if len(routes) != 1 {
t.Fatalf("got %d routes, want 1", len(routes))
}
r := routes[0]
if r.Mode != "l7" {
t.Fatalf("mode = %q, want %q", r.Mode, "l7")
}
if r.TLSCert != "/certs/api.crt" {
t.Fatalf("tls_cert = %q, want %q", r.TLSCert, "/certs/api.crt")
}
if r.TLSKey != "/certs/api.key" {
t.Fatalf("tls_key = %q, want %q", r.TLSKey, "/certs/api.key")
}
if r.BackendTLS {
t.Fatal("expected backend_tls = false")
}
if !r.SendProxyProtocol {
t.Fatal("expected send_proxy_protocol = true")
}
}
func TestRouteDuplicateHostname(t *testing.T) {
store := openTestDB(t)
listenerID, _ := store.CreateListener(":443")
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443"); err != nil {
listenerID, _ := store.CreateListener(":443", false)
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false); err != nil {
t.Fatalf("first create: %v", err)
}
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443"); err == nil {
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443", "l4", "", "", false, false); err == nil {
t.Fatal("expected error for duplicate hostname on same listener")
}
}
@@ -163,9 +226,9 @@ func TestRouteDuplicateHostname(t *testing.T) {
func TestRouteCascadeDelete(t *testing.T) {
store := openTestDB(t)
listenerID, _ := store.CreateListener(":443")
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443")
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443")
listenerID, _ := store.CreateListener(":443", false)
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
if err := store.DeleteListener(listenerID); err != nil {
t.Fatalf("delete listener: %v", err)
@@ -237,14 +300,15 @@ func TestSeed(t *testing.T) {
{
Addr: ":443",
Routes: []config.Route{
{Hostname: "a.example.com", Backend: "127.0.0.1:8443"},
{Hostname: "a.example.com", Backend: "127.0.0.1:8443", Mode: "l4"},
{Hostname: "b.example.com", Backend: "127.0.0.1:9443"},
},
},
{
Addr: ":8443",
Addr: ":8443",
ProxyProtocol: true,
Routes: []config.Route{
{Hostname: "c.example.com", Backend: "127.0.0.1:18443"},
{Hostname: "c.example.com", Backend: "127.0.0.1:18443", Mode: "l4", SendProxyProtocol: true},
},
},
}
@@ -266,6 +330,9 @@ func TestSeed(t *testing.T) {
if len(dbListeners) != 2 {
t.Fatalf("got %d listeners, want 2", len(dbListeners))
}
if !dbListeners[1].ProxyProtocol {
t.Fatal("expected listener 2 proxy_protocol = true")
}
routes, err := store.ListRoutes(dbListeners[0].ID)
if err != nil {
@@ -275,6 +342,25 @@ func TestSeed(t *testing.T) {
t.Fatalf("got %d routes for listener 0, want 2", len(routes))
}
// Verify mode defaults to "l4" even when empty in config.
for _, r := range routes {
if r.Mode != "l4" {
t.Fatalf("route %q mode = %q, want %q", r.Hostname, r.Mode, "l4")
}
}
// Verify send_proxy_protocol on listener 2's route.
routes2, err := store.ListRoutes(dbListeners[1].ID)
if err != nil {
t.Fatalf("list routes listener 2: %v", err)
}
if len(routes2) != 1 {
t.Fatalf("got %d routes for listener 1, want 1", len(routes2))
}
if !routes2[0].SendProxyProtocol {
t.Fatal("expected send_proxy_protocol = true on listener 2 route")
}
rules, err := store.ListFirewallRules()
if err != nil {
t.Fatalf("list firewall rules: %v", err)
@@ -287,7 +373,7 @@ func TestSeed(t *testing.T) {
func TestSnapshot(t *testing.T) {
store := openTestDB(t)
store.CreateListener(":443")
store.CreateListener(":443", false)
dest := filepath.Join(t.TempDir(), "backup.db")
if err := store.Snapshot(dest); err != nil {
@@ -329,3 +415,57 @@ func TestDeleteNonexistent(t *testing.T) {
t.Fatal("expected error deleting nonexistent firewall rule")
}
}
// TestMigrationV2Upgrade verifies that migration v2 adds new columns
// to an existing v1 database without data loss.
func TestMigrationV2Upgrade(t *testing.T) {
dir := t.TempDir()
store, err := Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open: %v", err)
}
t.Cleanup(func() { store.Close() })
// Run full migrations (v1 + v2).
if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
// Insert a listener and route with defaults to verify new columns work.
lid, err := store.CreateListener(":443", false)
if err != nil {
t.Fatalf("create listener: %v", err)
}
_, err = store.CreateRoute(lid, "test.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
if err != nil {
t.Fatalf("create route: %v", err)
}
// Read back and verify defaults.
routes, err := store.ListRoutes(lid)
if err != nil {
t.Fatalf("list routes: %v", err)
}
if len(routes) != 1 {
t.Fatalf("got %d routes, want 1", len(routes))
}
r := routes[0]
if r.Mode != "l4" {
t.Fatalf("mode = %q, want %q", r.Mode, "l4")
}
if r.TLSCert != "" || r.TLSKey != "" {
t.Fatalf("expected empty cert/key, got cert=%q key=%q", r.TLSCert, r.TLSKey)
}
if r.BackendTLS || r.SendProxyProtocol {
t.Fatal("expected false for backend_tls and send_proxy_protocol")
}
listeners, err := store.ListListeners()
if err != nil {
t.Fatalf("list listeners: %v", err)
}
if listeners[0].ProxyProtocol {
t.Fatal("expected proxy_protocol = false")
}
}

View File

@@ -4,13 +4,14 @@ import "fmt"
// Listener is a database listener record.
type Listener struct {
ID int64
Addr string
ID int64
Addr string
ProxyProtocol bool
}
// ListListeners returns all listeners.
func (s *Store) ListListeners() ([]Listener, error) {
rows, err := s.db.Query("SELECT id, addr FROM listeners ORDER BY id")
rows, err := s.db.Query("SELECT id, addr, proxy_protocol FROM listeners ORDER BY id")
if err != nil {
return nil, fmt.Errorf("querying listeners: %w", err)
}
@@ -19,7 +20,7 @@ func (s *Store) ListListeners() ([]Listener, error) {
var listeners []Listener
for rows.Next() {
var l Listener
if err := rows.Scan(&l.ID, &l.Addr); err != nil {
if err := rows.Scan(&l.ID, &l.Addr, &l.ProxyProtocol); err != nil {
return nil, fmt.Errorf("scanning listener: %w", err)
}
listeners = append(listeners, l)
@@ -28,8 +29,11 @@ func (s *Store) ListListeners() ([]Listener, error) {
}
// CreateListener inserts a listener and returns its ID.
func (s *Store) CreateListener(addr string) (int64, error) {
result, err := s.db.Exec("INSERT INTO listeners (addr) VALUES (?)", addr)
func (s *Store) CreateListener(addr string, proxyProtocol bool) (int64, error) {
result, err := s.db.Exec(
"INSERT INTO listeners (addr, proxy_protocol) VALUES (?, ?)",
addr, proxyProtocol,
)
if err != nil {
return 0, fmt.Errorf("inserting listener: %w", err)
}
@@ -52,8 +56,8 @@ func (s *Store) DeleteListener(id int64) error {
// GetListenerByAddr returns a listener by its address.
func (s *Store) GetListenerByAddr(addr string) (Listener, error) {
var l Listener
err := s.db.QueryRow("SELECT id, addr FROM listeners WHERE addr = ?", addr).
Scan(&l.ID, &l.Addr)
err := s.db.QueryRow("SELECT id, addr, proxy_protocol FROM listeners WHERE addr = ?", addr).
Scan(&l.ID, &l.Addr, &l.ProxyProtocol)
if err != nil {
return Listener{}, fmt.Errorf("querying listener by addr %q: %w", addr, err)
}

View File

@@ -13,6 +13,7 @@ type migration struct {
var migrations = []migration{
{1, "create_core_tables", migrate001CreateCoreTables},
{2, "add_proxy_protocol_and_l7_fields", migrate002AddL7Fields},
}
// Migrate runs all unapplied migrations sequentially.
@@ -91,3 +92,21 @@ func migrate001CreateCoreTables(tx *sql.Tx) error {
}
return nil
}
func migrate002AddL7Fields(tx *sql.Tx) error {
stmts := []string{
`ALTER TABLE listeners ADD COLUMN proxy_protocol INTEGER NOT NULL DEFAULT 0`,
`ALTER TABLE routes ADD COLUMN mode TEXT NOT NULL DEFAULT 'l4'`,
`ALTER TABLE routes ADD COLUMN tls_cert TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE routes ADD COLUMN tls_key TEXT NOT NULL DEFAULT ''`,
`ALTER TABLE routes ADD COLUMN backend_tls INTEGER NOT NULL DEFAULT 0`,
`ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0`,
}
for _, stmt := range stmts {
if _, err := tx.Exec(stmt); err != nil {
return err
}
}
return nil
}

View File

@@ -4,16 +4,22 @@ import "fmt"
// Route is a database route record.
type Route struct {
ID int64
ListenerID int64
Hostname string
Backend string
ID int64
ListenerID int64
Hostname string
Backend string
Mode string // "l4" or "l7"
TLSCert string
TLSKey string
BackendTLS bool
SendProxyProtocol bool
}
// ListRoutes returns all routes for a listener.
func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
rows, err := s.db.Query(
"SELECT id, listener_id, hostname, backend FROM routes WHERE listener_id = ? ORDER BY hostname",
`SELECT id, listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol
FROM routes WHERE listener_id = ? ORDER BY hostname`,
listenerID,
)
if err != nil {
@@ -24,7 +30,8 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
var routes []Route
for rows.Next() {
var r Route
if err := rows.Scan(&r.ID, &r.ListenerID, &r.Hostname, &r.Backend); err != nil {
if err := rows.Scan(&r.ID, &r.ListenerID, &r.Hostname, &r.Backend,
&r.Mode, &r.TLSCert, &r.TLSKey, &r.BackendTLS, &r.SendProxyProtocol); err != nil {
return nil, fmt.Errorf("scanning route: %w", err)
}
routes = append(routes, r)
@@ -33,10 +40,11 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
}
// CreateRoute inserts a route and returns its ID.
func (s *Store) CreateRoute(listenerID int64, hostname, backend string) (int64, error) {
func (s *Store) CreateRoute(listenerID int64, hostname, backend, mode, tlsCert, tlsKey string, backendTLS, sendProxyProtocol bool) (int64, error) {
result, err := s.db.Exec(
"INSERT INTO routes (listener_id, hostname, backend) VALUES (?, ?, ?)",
listenerID, hostname, backend,
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
listenerID, hostname, backend, mode, tlsCert, tlsKey, backendTLS, sendProxyProtocol,
)
if err != nil {
return 0, fmt.Errorf("inserting route: %w", err)

View File

@@ -17,16 +17,25 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
defer tx.Rollback()
for _, l := range listeners {
result, err := tx.Exec("INSERT INTO listeners (addr) VALUES (?)", l.Addr)
result, err := tx.Exec(
"INSERT INTO listeners (addr, proxy_protocol) VALUES (?, ?)",
l.Addr, l.ProxyProtocol,
)
if err != nil {
return fmt.Errorf("seeding listener %q: %w", l.Addr, err)
}
listenerID, _ := result.LastInsertId()
for _, r := range l.Routes {
mode := r.Mode
if mode == "" {
mode = "l4"
}
_, err := tx.Exec(
"INSERT INTO routes (listener_id, hostname, backend) VALUES (?, ?, ?)",
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
listenerID, strings.ToLower(r.Hostname), r.Backend,
mode, r.TLSCert, r.TLSKey, r.BackendTLS, r.SendProxyProtocol,
)
if err != nil {
return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err)

View File

@@ -88,10 +88,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
resp := &pb.ListRoutesResponse{
ListenerAddr: ls.Addr,
}
for hostname, backend := range routes {
for hostname, route := range routes {
resp.Routes = append(resp.Routes, &pb.Route{
Hostname: hostname,
Backend: backend,
Backend: route.Backend,
})
}
return resp, nil
@@ -119,11 +119,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
hostname := strings.ToLower(req.Route.Hostname)
// Write-through: DB first, then memory.
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend); err != nil {
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, "l4", "", "", false, false); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
}
if err := ls.AddRoute(hostname, req.Route.Backend); err != nil {
if err := ls.AddRoute(hostname, server.RouteInfo{Backend: req.Route.Backend, Mode: "l4"}); err != nil {
// DB succeeded but memory failed (should not happen since DB enforces uniqueness).
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
}

View File

@@ -87,14 +87,18 @@ func setup(t *testing.T) *testEnv {
if err != nil {
t.Fatalf("list routes: %v", err)
}
routes := make(map[string]string, len(dbRoutes))
routes := make(map[string]server.RouteInfo, len(dbRoutes))
for _, r := range dbRoutes {
routes[r.Hostname] = r.Backend
routes[r.Hostname] = server.RouteInfo{
Backend: r.Backend,
Mode: r.Mode,
}
}
listenerData = append(listenerData, server.ListenerData{
ID: l.ID,
Addr: l.Addr,
Routes: routes,
ID: l.ID,
Addr: l.Addr,
ProxyProtocol: l.ProxyProtocol,
Routes: routes,
})
}

View File

@@ -17,11 +17,22 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
)
// RouteInfo holds the full configuration for a single route.
type RouteInfo struct {
Backend string
Mode string // "l4" or "l7"
TLSCert string
TLSKey string
BackendTLS bool
SendProxyProtocol bool
}
// ListenerState holds the mutable state for a single proxy listener.
type ListenerState struct {
ID int64 // database primary key
Addr string
routes map[string]string // lowercase hostname → backend addr
ProxyProtocol bool
routes map[string]RouteInfo // lowercase hostname → route info
mu sync.RWMutex
ActiveConnections atomic.Int64
activeConns map[net.Conn]struct{} // tracked for forced shutdown
@@ -29,11 +40,11 @@ type ListenerState struct {
}
// Routes returns a snapshot of the listener's route table.
func (ls *ListenerState) Routes() map[string]string {
func (ls *ListenerState) Routes() map[string]RouteInfo {
ls.mu.RLock()
defer ls.mu.RUnlock()
m := make(map[string]string, len(ls.routes))
m := make(map[string]RouteInfo, len(ls.routes))
for k, v := range ls.routes {
m[k] = v
}
@@ -42,7 +53,7 @@ func (ls *ListenerState) Routes() map[string]string {
// AddRoute adds a route to the listener. Returns an error if the hostname
// already exists.
func (ls *ListenerState) AddRoute(hostname, backend string) error {
func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) error {
key := strings.ToLower(hostname)
ls.mu.Lock()
@@ -51,7 +62,7 @@ func (ls *ListenerState) AddRoute(hostname, backend string) error {
if _, ok := ls.routes[key]; ok {
return fmt.Errorf("route %q already exists", hostname)
}
ls.routes[key] = backend
ls.routes[key] = info
return nil
}
@@ -70,19 +81,20 @@ func (ls *ListenerState) RemoveRoute(hostname string) error {
return nil
}
func (ls *ListenerState) lookupRoute(hostname string) (string, bool) {
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
ls.mu.RLock()
defer ls.mu.RUnlock()
backend, ok := ls.routes[hostname]
return backend, ok
info, ok := ls.routes[hostname]
return info, ok
}
// ListenerData holds the data needed to construct a ListenerState.
type ListenerData struct {
ID int64
Addr string
Routes map[string]string // lowercase hostname → backend
ID int64
Addr string
ProxyProtocol bool
Routes map[string]RouteInfo // lowercase hostname → route info
}
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
@@ -102,10 +114,11 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData,
var listeners []*ListenerState
for _, ld := range listenerData {
listeners = append(listeners, &ListenerState{
ID: ld.ID,
Addr: ld.Addr,
routes: ld.Routes,
activeConns: make(map[net.Conn]struct{}),
ID: ld.ID,
Addr: ld.Addr,
ProxyProtocol: ld.ProxyProtocol,
routes: ld.Routes,
activeConns: make(map[net.Conn]struct{}),
})
}
@@ -264,20 +277,32 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
return
}
backend, ok := ls.lookupRoute(hostname)
route, ok := ls.lookupRoute(hostname)
if !ok {
s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
return
}
backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration)
// Dispatch based on route mode. L7 will be implemented in a later phase.
switch route.Mode {
case "l7":
s.logger.Error("L7 mode not yet implemented", "hostname", hostname)
return
default:
s.handleL4(ctx, conn, ls, addr, hostname, route, peeked)
}
}
// handleL4 handles an L4 (passthrough) connection.
func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState, addr netip.Addr, hostname string, route RouteInfo, peeked []byte) {
backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration)
if err != nil {
s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err)
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
return
}
defer backendConn.Close()
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend)
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
if err != nil && ctx.Err() == nil {

View File

@@ -14,6 +14,11 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
)
// l4Route creates a RouteInfo for an L4 passthrough route.
func l4Route(backend string) RouteInfo {
return RouteInfo{Backend: backend, Mode: "l4"}
}
// echoServer accepts one connection, copies everything back, then closes.
func echoServer(t *testing.T, ln net.Listener) {
t.Helper()
@@ -85,8 +90,8 @@ func TestProxyRoundTrip(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"echo.test": backendLn.Addr().String(),
Routes: map[string]RouteInfo{
"echo.test": l4Route(backendLn.Addr().String()),
},
},
})
@@ -141,8 +146,8 @@ func TestNoRouteResets(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"other.test": "127.0.0.1:1", // exists but won't match
Routes: map[string]RouteInfo{
"other.test": l4Route("127.0.0.1:1"), // exists but won't match
},
},
})
@@ -212,8 +217,8 @@ func TestFirewallBlocks(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"echo.test": backendLn.Addr().String(),
Routes: map[string]RouteInfo{
"echo.test": l4Route(backendLn.Addr().String()),
},
},
}, logger, "test")
@@ -267,7 +272,7 @@ func TestNotTLSResets(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{"x.test": "127.0.0.1:1"},
Routes: map[string]RouteInfo{"x.test": l4Route("127.0.0.1:1")},
},
})
@@ -325,8 +330,8 @@ func TestConnectionTracking(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"conn.test": backendLn.Addr().String(),
Routes: map[string]RouteInfo{
"conn.test": l4Route(backendLn.Addr().String()),
},
},
})
@@ -432,8 +437,8 @@ func TestMultipleListeners(t *testing.T) {
ln2.Close()
srv := newTestServer(t, []ListenerData{
{ID: 1, Addr: addr1, Routes: map[string]string{"svc.test": backendA.Addr().String()}},
{ID: 2, Addr: addr2, Routes: map[string]string{"svc.test": backendB.Addr().String()}},
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
{ID: 2, Addr: addr2, Routes: map[string]RouteInfo{"svc.test": l4Route(backendB.Addr().String())}},
})
stop := startAndStop(t, srv)
@@ -500,8 +505,8 @@ func TestCaseInsensitiveRouting(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"echo.test": backendLn.Addr().String(),
Routes: map[string]RouteInfo{
"echo.test": l4Route(backendLn.Addr().String()),
},
},
})
@@ -549,8 +554,8 @@ func TestBackendUnreachable(t *testing.T) {
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"dead.test": deadAddr,
Routes: map[string]RouteInfo{
"dead.test": l4Route(deadAddr),
},
},
})
@@ -612,7 +617,7 @@ func TestGracefulShutdown(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
srv := New(cfg, fw, []ListenerData{
{ID: 1, Addr: proxyAddr, Routes: map[string]string{"hold.test": backendLn.Addr().String()}},
{ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{"hold.test": l4Route(backendLn.Addr().String())}},
}, logger, "test")
ctx, cancel := context.WithCancel(context.Background())
@@ -651,18 +656,18 @@ func TestListenerStateRoutes(t *testing.T) {
ls := &ListenerState{
ID: 1,
Addr: ":443",
routes: map[string]string{
"a.test": "127.0.0.1:1",
routes: map[string]RouteInfo{
"a.test": l4Route("127.0.0.1:1"),
},
}
// AddRoute
if err := ls.AddRoute("b.test", "127.0.0.1:2"); err != nil {
if err := ls.AddRoute("b.test", l4Route("127.0.0.1:2")); err != nil {
t.Fatalf("AddRoute: %v", err)
}
// AddRoute duplicate
if err := ls.AddRoute("b.test", "127.0.0.1:3"); err == nil {
if err := ls.AddRoute("b.test", l4Route("127.0.0.1:3")); err == nil {
t.Fatal("expected error for duplicate route")
}
@@ -686,8 +691,8 @@ func TestListenerStateRoutes(t *testing.T) {
if len(routes) != 1 {
t.Fatalf("expected 1 route, got %d", len(routes))
}
if routes["b.test"] != "127.0.0.1:2" {
t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"])
if routes["b.test"].Backend != "127.0.0.1:2" {
t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"].Backend)
}
}