Add L7/PROXY protocol data model, config, and architecture docs
Extend the config, database schema, and server internals to support per-route L4/L7 mode selection and PROXY protocol fields. This is the foundation for L7 HTTP/2 reverse proxying and multi-hop PROXY protocol support described in the updated ARCHITECTURE.md. Config: Listener gains ProxyProtocol; Route gains Mode, TLSCert, TLSKey, BackendTLS, SendProxyProtocol. L7 routes validated at load time (cert/key pair must exist and parse). Mode defaults to "l4". DB: Migration v2 adds columns to listeners and routes tables. CRUD and seeding updated to persist all new fields. Server: RouteInfo replaces bare backend string in route lookup. handleConn dispatches on route.Mode (L7 path stubbed with error). ListenerState and ListenerData carry ProxyProtocol flag. All existing L4 tests pass unchanged. New tests cover migration v2, L7 field persistence, config validation for mode/cert/key, and proxy_protocol flag round-tripping. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -27,13 +28,19 @@ type GRPC struct {
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
Addr string `toml:"addr"`
|
||||
Routes []Route `toml:"routes"`
|
||||
Addr string `toml:"addr"`
|
||||
ProxyProtocol bool `toml:"proxy_protocol"`
|
||||
Routes []Route `toml:"routes"`
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Backend string `toml:"backend"`
|
||||
Hostname string `toml:"hostname"`
|
||||
Backend string `toml:"backend"`
|
||||
Mode string `toml:"mode"` // "l4" (default) or "l7"
|
||||
TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only)
|
||||
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
|
||||
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
|
||||
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
|
||||
}
|
||||
|
||||
type Firewall struct {
|
||||
@@ -163,12 +170,14 @@ func (c *Config) validate() error {
|
||||
}
|
||||
|
||||
// Validate listeners if provided (used for seeding on first run).
|
||||
for i, l := range c.Listeners {
|
||||
for i := range c.Listeners {
|
||||
l := &c.Listeners[i]
|
||||
if l.Addr == "" {
|
||||
return fmt.Errorf("listener %d: addr is required", i)
|
||||
}
|
||||
seen := make(map[string]bool)
|
||||
for j, r := range l.Routes {
|
||||
for j := range l.Routes {
|
||||
r := &l.Routes[j]
|
||||
if r.Hostname == "" {
|
||||
return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j)
|
||||
}
|
||||
@@ -179,6 +188,27 @@ func (c *Config) validate() error {
|
||||
return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname)
|
||||
}
|
||||
seen[r.Hostname] = true
|
||||
|
||||
// Normalize mode: empty defaults to "l4".
|
||||
if r.Mode == "" {
|
||||
r.Mode = "l4"
|
||||
}
|
||||
if r.Mode != "l4" && r.Mode != "l7" {
|
||||
return fmt.Errorf("listener %d (%s), route %d (%s): mode must be \"l4\" or \"l7\", got %q",
|
||||
i, l.Addr, j, r.Hostname, r.Mode)
|
||||
}
|
||||
|
||||
// L7 routes require TLS cert and key.
|
||||
if r.Mode == "l7" {
|
||||
if r.TLSCert == "" || r.TLSKey == "" {
|
||||
return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key",
|
||||
i, l.Addr, j, r.Hostname)
|
||||
}
|
||||
if _, err := tls.LoadX509KeyPair(r.TLSCert, r.TLSKey); err != nil {
|
||||
return fmt.Errorf("listener %d (%s), route %d (%s): loading TLS cert/key: %w",
|
||||
i, l.Addr, j, r.Hostname, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -391,3 +391,147 @@ path = "/tmp/test.db"
|
||||
t.Fatalf("got grpc.addr %q, want %q", cfg.GRPC.Addr, "/var/run/override.sock")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadL4ModeDefault(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[database]
|
||||
path = "/tmp/test.db"
|
||||
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Mode should be normalized to "l4" when unset.
|
||||
if cfg.Listeners[0].Routes[0].Mode != "l4" {
|
||||
t.Fatalf("got mode %q, want %q", cfg.Listeners[0].Routes[0].Mode, "l4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadInvalidMode(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[database]
|
||||
path = "/tmp/test.db"
|
||||
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
mode = "l5"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadL7RequiresCertKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[database]
|
||||
path = "/tmp/test.db"
|
||||
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8080"
|
||||
mode = "l7"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for L7 route without cert/key")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadL7InvalidCertKey(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[database]
|
||||
path = "/tmp/test.db"
|
||||
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8080"
|
||||
mode = "l7"
|
||||
tls_cert = "/nonexistent/cert.pem"
|
||||
tls_key = "/nonexistent/key.pem"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for L7 route with nonexistent cert/key files")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadProxyProtocol(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[database]
|
||||
path = "/tmp/test.db"
|
||||
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
proxy_protocol = true
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
send_proxy_protocol = true
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if !cfg.Listeners[0].ProxyProtocol {
|
||||
t.Fatal("expected proxy_protocol = true")
|
||||
}
|
||||
if !cfg.Listeners[0].Routes[0].SendProxyProtocol {
|
||||
t.Fatal("expected send_proxy_protocol = true")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ func TestIsEmpty(t *testing.T) {
|
||||
t.Fatal("expected empty database")
|
||||
}
|
||||
|
||||
if _, err := store.CreateListener(":443"); err != nil {
|
||||
if _, err := store.CreateListener(":443", false); err != nil {
|
||||
t.Fatalf("create listener: %v", err)
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestIsEmpty(t *testing.T) {
|
||||
func TestListenerCRUD(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.CreateListener(":443")
|
||||
id, err := store.CreateListener(":443", false)
|
||||
if err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
@@ -75,6 +75,9 @@ func TestListenerCRUD(t *testing.T) {
|
||||
if listeners[0].Addr != ":443" {
|
||||
t.Fatalf("got addr %q, want %q", listeners[0].Addr, ":443")
|
||||
}
|
||||
if listeners[0].ProxyProtocol {
|
||||
t.Fatal("expected proxy_protocol = false")
|
||||
}
|
||||
|
||||
l, err := store.GetListenerByAddr(":443")
|
||||
if err != nil {
|
||||
@@ -97,13 +100,33 @@ func TestListenerCRUD(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerProxyProtocol(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
id, err := store.CreateListener(":443", true)
|
||||
if err != nil {
|
||||
t.Fatalf("create: %v", err)
|
||||
}
|
||||
|
||||
l, err := store.GetListenerByAddr(":443")
|
||||
if err != nil {
|
||||
t.Fatalf("get by addr: %v", err)
|
||||
}
|
||||
if l.ID != id {
|
||||
t.Fatalf("got ID %d, want %d", l.ID, id)
|
||||
}
|
||||
if !l.ProxyProtocol {
|
||||
t.Fatal("expected proxy_protocol = true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerDuplicateAddr(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
if _, err := store.CreateListener(":443"); err != nil {
|
||||
if _, err := store.CreateListener(":443", false); err != nil {
|
||||
t.Fatalf("first create: %v", err)
|
||||
}
|
||||
if _, err := store.CreateListener(":443"); err == nil {
|
||||
if _, err := store.CreateListener(":443", false); err == nil {
|
||||
t.Fatal("expected error for duplicate addr")
|
||||
}
|
||||
}
|
||||
@@ -111,12 +134,12 @@ func TestListenerDuplicateAddr(t *testing.T) {
|
||||
func TestRouteCRUD(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
listenerID, err := store.CreateListener(":443")
|
||||
listenerID, err := store.CreateListener(":443", false)
|
||||
if err != nil {
|
||||
t.Fatalf("create listener: %v", err)
|
||||
}
|
||||
|
||||
routeID, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443")
|
||||
routeID, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("create route: %v", err)
|
||||
}
|
||||
@@ -134,6 +157,9 @@ func TestRouteCRUD(t *testing.T) {
|
||||
if routes[0].Hostname != "example.com" {
|
||||
t.Fatalf("got hostname %q, want %q", routes[0].Hostname, "example.com")
|
||||
}
|
||||
if routes[0].Mode != "l4" {
|
||||
t.Fatalf("got mode %q, want %q", routes[0].Mode, "l4")
|
||||
}
|
||||
|
||||
if err := store.DeleteRoute(listenerID, "example.com"); err != nil {
|
||||
t.Fatalf("delete route: %v", err)
|
||||
@@ -148,14 +174,51 @@ func TestRouteCRUD(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteL7Fields(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
listenerID, _ := store.CreateListener(":443", false)
|
||||
|
||||
_, err := store.CreateRoute(listenerID, "api.example.com", "127.0.0.1:8080", "l7",
|
||||
"/certs/api.crt", "/certs/api.key", false, true)
|
||||
if err != nil {
|
||||
t.Fatalf("create L7 route: %v", err)
|
||||
}
|
||||
|
||||
routes, err := store.ListRoutes(listenerID)
|
||||
if err != nil {
|
||||
t.Fatalf("list routes: %v", err)
|
||||
}
|
||||
if len(routes) != 1 {
|
||||
t.Fatalf("got %d routes, want 1", len(routes))
|
||||
}
|
||||
|
||||
r := routes[0]
|
||||
if r.Mode != "l7" {
|
||||
t.Fatalf("mode = %q, want %q", r.Mode, "l7")
|
||||
}
|
||||
if r.TLSCert != "/certs/api.crt" {
|
||||
t.Fatalf("tls_cert = %q, want %q", r.TLSCert, "/certs/api.crt")
|
||||
}
|
||||
if r.TLSKey != "/certs/api.key" {
|
||||
t.Fatalf("tls_key = %q, want %q", r.TLSKey, "/certs/api.key")
|
||||
}
|
||||
if r.BackendTLS {
|
||||
t.Fatal("expected backend_tls = false")
|
||||
}
|
||||
if !r.SendProxyProtocol {
|
||||
t.Fatal("expected send_proxy_protocol = true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteDuplicateHostname(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
listenerID, _ := store.CreateListener(":443")
|
||||
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443"); err != nil {
|
||||
listenerID, _ := store.CreateListener(":443", false)
|
||||
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false); err != nil {
|
||||
t.Fatalf("first create: %v", err)
|
||||
}
|
||||
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443"); err == nil {
|
||||
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443", "l4", "", "", false, false); err == nil {
|
||||
t.Fatal("expected error for duplicate hostname on same listener")
|
||||
}
|
||||
}
|
||||
@@ -163,9 +226,9 @@ func TestRouteDuplicateHostname(t *testing.T) {
|
||||
func TestRouteCascadeDelete(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
listenerID, _ := store.CreateListener(":443")
|
||||
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443")
|
||||
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443")
|
||||
listenerID, _ := store.CreateListener(":443", false)
|
||||
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
|
||||
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
|
||||
|
||||
if err := store.DeleteListener(listenerID); err != nil {
|
||||
t.Fatalf("delete listener: %v", err)
|
||||
@@ -237,14 +300,15 @@ func TestSeed(t *testing.T) {
|
||||
{
|
||||
Addr: ":443",
|
||||
Routes: []config.Route{
|
||||
{Hostname: "a.example.com", Backend: "127.0.0.1:8443"},
|
||||
{Hostname: "a.example.com", Backend: "127.0.0.1:8443", Mode: "l4"},
|
||||
{Hostname: "b.example.com", Backend: "127.0.0.1:9443"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Addr: ":8443",
|
||||
Addr: ":8443",
|
||||
ProxyProtocol: true,
|
||||
Routes: []config.Route{
|
||||
{Hostname: "c.example.com", Backend: "127.0.0.1:18443"},
|
||||
{Hostname: "c.example.com", Backend: "127.0.0.1:18443", Mode: "l4", SendProxyProtocol: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -266,6 +330,9 @@ func TestSeed(t *testing.T) {
|
||||
if len(dbListeners) != 2 {
|
||||
t.Fatalf("got %d listeners, want 2", len(dbListeners))
|
||||
}
|
||||
if !dbListeners[1].ProxyProtocol {
|
||||
t.Fatal("expected listener 2 proxy_protocol = true")
|
||||
}
|
||||
|
||||
routes, err := store.ListRoutes(dbListeners[0].ID)
|
||||
if err != nil {
|
||||
@@ -275,6 +342,25 @@ func TestSeed(t *testing.T) {
|
||||
t.Fatalf("got %d routes for listener 0, want 2", len(routes))
|
||||
}
|
||||
|
||||
// Verify mode defaults to "l4" even when empty in config.
|
||||
for _, r := range routes {
|
||||
if r.Mode != "l4" {
|
||||
t.Fatalf("route %q mode = %q, want %q", r.Hostname, r.Mode, "l4")
|
||||
}
|
||||
}
|
||||
|
||||
// Verify send_proxy_protocol on listener 2's route.
|
||||
routes2, err := store.ListRoutes(dbListeners[1].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("list routes listener 2: %v", err)
|
||||
}
|
||||
if len(routes2) != 1 {
|
||||
t.Fatalf("got %d routes for listener 1, want 1", len(routes2))
|
||||
}
|
||||
if !routes2[0].SendProxyProtocol {
|
||||
t.Fatal("expected send_proxy_protocol = true on listener 2 route")
|
||||
}
|
||||
|
||||
rules, err := store.ListFirewallRules()
|
||||
if err != nil {
|
||||
t.Fatalf("list firewall rules: %v", err)
|
||||
@@ -287,7 +373,7 @@ func TestSeed(t *testing.T) {
|
||||
func TestSnapshot(t *testing.T) {
|
||||
store := openTestDB(t)
|
||||
|
||||
store.CreateListener(":443")
|
||||
store.CreateListener(":443", false)
|
||||
|
||||
dest := filepath.Join(t.TempDir(), "backup.db")
|
||||
if err := store.Snapshot(dest); err != nil {
|
||||
@@ -329,3 +415,57 @@ func TestDeleteNonexistent(t *testing.T) {
|
||||
t.Fatal("expected error deleting nonexistent firewall rule")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMigrationV2Upgrade verifies that migration v2 adds new columns
|
||||
// to an existing v1 database without data loss.
|
||||
func TestMigrationV2Upgrade(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
store, err := Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
|
||||
// Run full migrations (v1 + v2).
|
||||
if err := store.Migrate(); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
// Insert a listener and route with defaults to verify new columns work.
|
||||
lid, err := store.CreateListener(":443", false)
|
||||
if err != nil {
|
||||
t.Fatalf("create listener: %v", err)
|
||||
}
|
||||
|
||||
_, err = store.CreateRoute(lid, "test.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
|
||||
if err != nil {
|
||||
t.Fatalf("create route: %v", err)
|
||||
}
|
||||
|
||||
// Read back and verify defaults.
|
||||
routes, err := store.ListRoutes(lid)
|
||||
if err != nil {
|
||||
t.Fatalf("list routes: %v", err)
|
||||
}
|
||||
if len(routes) != 1 {
|
||||
t.Fatalf("got %d routes, want 1", len(routes))
|
||||
}
|
||||
r := routes[0]
|
||||
if r.Mode != "l4" {
|
||||
t.Fatalf("mode = %q, want %q", r.Mode, "l4")
|
||||
}
|
||||
if r.TLSCert != "" || r.TLSKey != "" {
|
||||
t.Fatalf("expected empty cert/key, got cert=%q key=%q", r.TLSCert, r.TLSKey)
|
||||
}
|
||||
if r.BackendTLS || r.SendProxyProtocol {
|
||||
t.Fatal("expected false for backend_tls and send_proxy_protocol")
|
||||
}
|
||||
|
||||
listeners, err := store.ListListeners()
|
||||
if err != nil {
|
||||
t.Fatalf("list listeners: %v", err)
|
||||
}
|
||||
if listeners[0].ProxyProtocol {
|
||||
t.Fatal("expected proxy_protocol = false")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,13 +4,14 @@ import "fmt"
|
||||
|
||||
// Listener is a database listener record.
|
||||
type Listener struct {
|
||||
ID int64
|
||||
Addr string
|
||||
ID int64
|
||||
Addr string
|
||||
ProxyProtocol bool
|
||||
}
|
||||
|
||||
// ListListeners returns all listeners.
|
||||
func (s *Store) ListListeners() ([]Listener, error) {
|
||||
rows, err := s.db.Query("SELECT id, addr FROM listeners ORDER BY id")
|
||||
rows, err := s.db.Query("SELECT id, addr, proxy_protocol FROM listeners ORDER BY id")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("querying listeners: %w", err)
|
||||
}
|
||||
@@ -19,7 +20,7 @@ func (s *Store) ListListeners() ([]Listener, error) {
|
||||
var listeners []Listener
|
||||
for rows.Next() {
|
||||
var l Listener
|
||||
if err := rows.Scan(&l.ID, &l.Addr); err != nil {
|
||||
if err := rows.Scan(&l.ID, &l.Addr, &l.ProxyProtocol); err != nil {
|
||||
return nil, fmt.Errorf("scanning listener: %w", err)
|
||||
}
|
||||
listeners = append(listeners, l)
|
||||
@@ -28,8 +29,11 @@ func (s *Store) ListListeners() ([]Listener, error) {
|
||||
}
|
||||
|
||||
// CreateListener inserts a listener and returns its ID.
|
||||
func (s *Store) CreateListener(addr string) (int64, error) {
|
||||
result, err := s.db.Exec("INSERT INTO listeners (addr) VALUES (?)", addr)
|
||||
func (s *Store) CreateListener(addr string, proxyProtocol bool) (int64, error) {
|
||||
result, err := s.db.Exec(
|
||||
"INSERT INTO listeners (addr, proxy_protocol) VALUES (?, ?)",
|
||||
addr, proxyProtocol,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("inserting listener: %w", err)
|
||||
}
|
||||
@@ -52,8 +56,8 @@ func (s *Store) DeleteListener(id int64) error {
|
||||
// GetListenerByAddr returns a listener by its address.
|
||||
func (s *Store) GetListenerByAddr(addr string) (Listener, error) {
|
||||
var l Listener
|
||||
err := s.db.QueryRow("SELECT id, addr FROM listeners WHERE addr = ?", addr).
|
||||
Scan(&l.ID, &l.Addr)
|
||||
err := s.db.QueryRow("SELECT id, addr, proxy_protocol FROM listeners WHERE addr = ?", addr).
|
||||
Scan(&l.ID, &l.Addr, &l.ProxyProtocol)
|
||||
if err != nil {
|
||||
return Listener{}, fmt.Errorf("querying listener by addr %q: %w", addr, err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ type migration struct {
|
||||
|
||||
var migrations = []migration{
|
||||
{1, "create_core_tables", migrate001CreateCoreTables},
|
||||
{2, "add_proxy_protocol_and_l7_fields", migrate002AddL7Fields},
|
||||
}
|
||||
|
||||
// Migrate runs all unapplied migrations sequentially.
|
||||
@@ -91,3 +92,21 @@ func migrate001CreateCoreTables(tx *sql.Tx) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func migrate002AddL7Fields(tx *sql.Tx) error {
|
||||
stmts := []string{
|
||||
`ALTER TABLE listeners ADD COLUMN proxy_protocol INTEGER NOT NULL DEFAULT 0`,
|
||||
`ALTER TABLE routes ADD COLUMN mode TEXT NOT NULL DEFAULT 'l4'`,
|
||||
`ALTER TABLE routes ADD COLUMN tls_cert TEXT NOT NULL DEFAULT ''`,
|
||||
`ALTER TABLE routes ADD COLUMN tls_key TEXT NOT NULL DEFAULT ''`,
|
||||
`ALTER TABLE routes ADD COLUMN backend_tls INTEGER NOT NULL DEFAULT 0`,
|
||||
`ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0`,
|
||||
}
|
||||
|
||||
for _, stmt := range stmts {
|
||||
if _, err := tx.Exec(stmt); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -4,16 +4,22 @@ import "fmt"
|
||||
|
||||
// Route is a database route record.
|
||||
type Route struct {
|
||||
ID int64
|
||||
ListenerID int64
|
||||
Hostname string
|
||||
Backend string
|
||||
ID int64
|
||||
ListenerID int64
|
||||
Hostname string
|
||||
Backend string
|
||||
Mode string // "l4" or "l7"
|
||||
TLSCert string
|
||||
TLSKey string
|
||||
BackendTLS bool
|
||||
SendProxyProtocol bool
|
||||
}
|
||||
|
||||
// ListRoutes returns all routes for a listener.
|
||||
func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
|
||||
rows, err := s.db.Query(
|
||||
"SELECT id, listener_id, hostname, backend FROM routes WHERE listener_id = ? ORDER BY hostname",
|
||||
`SELECT id, listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol
|
||||
FROM routes WHERE listener_id = ? ORDER BY hostname`,
|
||||
listenerID,
|
||||
)
|
||||
if err != nil {
|
||||
@@ -24,7 +30,8 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
|
||||
var routes []Route
|
||||
for rows.Next() {
|
||||
var r Route
|
||||
if err := rows.Scan(&r.ID, &r.ListenerID, &r.Hostname, &r.Backend); err != nil {
|
||||
if err := rows.Scan(&r.ID, &r.ListenerID, &r.Hostname, &r.Backend,
|
||||
&r.Mode, &r.TLSCert, &r.TLSKey, &r.BackendTLS, &r.SendProxyProtocol); err != nil {
|
||||
return nil, fmt.Errorf("scanning route: %w", err)
|
||||
}
|
||||
routes = append(routes, r)
|
||||
@@ -33,10 +40,11 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
|
||||
}
|
||||
|
||||
// CreateRoute inserts a route and returns its ID.
|
||||
func (s *Store) CreateRoute(listenerID int64, hostname, backend string) (int64, error) {
|
||||
func (s *Store) CreateRoute(listenerID int64, hostname, backend, mode, tlsCert, tlsKey string, backendTLS, sendProxyProtocol bool) (int64, error) {
|
||||
result, err := s.db.Exec(
|
||||
"INSERT INTO routes (listener_id, hostname, backend) VALUES (?, ?, ?)",
|
||||
listenerID, hostname, backend,
|
||||
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
listenerID, hostname, backend, mode, tlsCert, tlsKey, backendTLS, sendProxyProtocol,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("inserting route: %w", err)
|
||||
|
||||
@@ -17,16 +17,25 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, l := range listeners {
|
||||
result, err := tx.Exec("INSERT INTO listeners (addr) VALUES (?)", l.Addr)
|
||||
result, err := tx.Exec(
|
||||
"INSERT INTO listeners (addr, proxy_protocol) VALUES (?, ?)",
|
||||
l.Addr, l.ProxyProtocol,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seeding listener %q: %w", l.Addr, err)
|
||||
}
|
||||
listenerID, _ := result.LastInsertId()
|
||||
|
||||
for _, r := range l.Routes {
|
||||
mode := r.Mode
|
||||
if mode == "" {
|
||||
mode = "l4"
|
||||
}
|
||||
_, err := tx.Exec(
|
||||
"INSERT INTO routes (listener_id, hostname, backend) VALUES (?, ?, ?)",
|
||||
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
listenerID, strings.ToLower(r.Hostname), r.Backend,
|
||||
mode, r.TLSCert, r.TLSKey, r.BackendTLS, r.SendProxyProtocol,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err)
|
||||
|
||||
@@ -88,10 +88,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
|
||||
resp := &pb.ListRoutesResponse{
|
||||
ListenerAddr: ls.Addr,
|
||||
}
|
||||
for hostname, backend := range routes {
|
||||
for hostname, route := range routes {
|
||||
resp.Routes = append(resp.Routes, &pb.Route{
|
||||
Hostname: hostname,
|
||||
Backend: backend,
|
||||
Backend: route.Backend,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
@@ -119,11 +119,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
|
||||
hostname := strings.ToLower(req.Route.Hostname)
|
||||
|
||||
// Write-through: DB first, then memory.
|
||||
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend); err != nil {
|
||||
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, "l4", "", "", false, false); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
}
|
||||
|
||||
if err := ls.AddRoute(hostname, req.Route.Backend); err != nil {
|
||||
if err := ls.AddRoute(hostname, server.RouteInfo{Backend: req.Route.Backend, Mode: "l4"}); err != nil {
|
||||
// DB succeeded but memory failed (should not happen since DB enforces uniqueness).
|
||||
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
||||
}
|
||||
|
||||
@@ -87,14 +87,18 @@ func setup(t *testing.T) *testEnv {
|
||||
if err != nil {
|
||||
t.Fatalf("list routes: %v", err)
|
||||
}
|
||||
routes := make(map[string]string, len(dbRoutes))
|
||||
routes := make(map[string]server.RouteInfo, len(dbRoutes))
|
||||
for _, r := range dbRoutes {
|
||||
routes[r.Hostname] = r.Backend
|
||||
routes[r.Hostname] = server.RouteInfo{
|
||||
Backend: r.Backend,
|
||||
Mode: r.Mode,
|
||||
}
|
||||
}
|
||||
listenerData = append(listenerData, server.ListenerData{
|
||||
ID: l.ID,
|
||||
Addr: l.Addr,
|
||||
Routes: routes,
|
||||
ID: l.ID,
|
||||
Addr: l.Addr,
|
||||
ProxyProtocol: l.ProxyProtocol,
|
||||
Routes: routes,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -17,11 +17,22 @@ import (
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
||||
)
|
||||
|
||||
// RouteInfo holds the full configuration for a single route.
|
||||
type RouteInfo struct {
|
||||
Backend string
|
||||
Mode string // "l4" or "l7"
|
||||
TLSCert string
|
||||
TLSKey string
|
||||
BackendTLS bool
|
||||
SendProxyProtocol bool
|
||||
}
|
||||
|
||||
// ListenerState holds the mutable state for a single proxy listener.
|
||||
type ListenerState struct {
|
||||
ID int64 // database primary key
|
||||
Addr string
|
||||
routes map[string]string // lowercase hostname → backend addr
|
||||
ProxyProtocol bool
|
||||
routes map[string]RouteInfo // lowercase hostname → route info
|
||||
mu sync.RWMutex
|
||||
ActiveConnections atomic.Int64
|
||||
activeConns map[net.Conn]struct{} // tracked for forced shutdown
|
||||
@@ -29,11 +40,11 @@ type ListenerState struct {
|
||||
}
|
||||
|
||||
// Routes returns a snapshot of the listener's route table.
|
||||
func (ls *ListenerState) Routes() map[string]string {
|
||||
func (ls *ListenerState) Routes() map[string]RouteInfo {
|
||||
ls.mu.RLock()
|
||||
defer ls.mu.RUnlock()
|
||||
|
||||
m := make(map[string]string, len(ls.routes))
|
||||
m := make(map[string]RouteInfo, len(ls.routes))
|
||||
for k, v := range ls.routes {
|
||||
m[k] = v
|
||||
}
|
||||
@@ -42,7 +53,7 @@ func (ls *ListenerState) Routes() map[string]string {
|
||||
|
||||
// AddRoute adds a route to the listener. Returns an error if the hostname
|
||||
// already exists.
|
||||
func (ls *ListenerState) AddRoute(hostname, backend string) error {
|
||||
func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) error {
|
||||
key := strings.ToLower(hostname)
|
||||
|
||||
ls.mu.Lock()
|
||||
@@ -51,7 +62,7 @@ func (ls *ListenerState) AddRoute(hostname, backend string) error {
|
||||
if _, ok := ls.routes[key]; ok {
|
||||
return fmt.Errorf("route %q already exists", hostname)
|
||||
}
|
||||
ls.routes[key] = backend
|
||||
ls.routes[key] = info
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -70,19 +81,20 @@ func (ls *ListenerState) RemoveRoute(hostname string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ls *ListenerState) lookupRoute(hostname string) (string, bool) {
|
||||
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
|
||||
ls.mu.RLock()
|
||||
defer ls.mu.RUnlock()
|
||||
|
||||
backend, ok := ls.routes[hostname]
|
||||
return backend, ok
|
||||
info, ok := ls.routes[hostname]
|
||||
return info, ok
|
||||
}
|
||||
|
||||
// ListenerData holds the data needed to construct a ListenerState.
|
||||
type ListenerData struct {
|
||||
ID int64
|
||||
Addr string
|
||||
Routes map[string]string // lowercase hostname → backend
|
||||
ID int64
|
||||
Addr string
|
||||
ProxyProtocol bool
|
||||
Routes map[string]RouteInfo // lowercase hostname → route info
|
||||
}
|
||||
|
||||
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
|
||||
@@ -102,10 +114,11 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData,
|
||||
var listeners []*ListenerState
|
||||
for _, ld := range listenerData {
|
||||
listeners = append(listeners, &ListenerState{
|
||||
ID: ld.ID,
|
||||
Addr: ld.Addr,
|
||||
routes: ld.Routes,
|
||||
activeConns: make(map[net.Conn]struct{}),
|
||||
ID: ld.ID,
|
||||
Addr: ld.Addr,
|
||||
ProxyProtocol: ld.ProxyProtocol,
|
||||
routes: ld.Routes,
|
||||
activeConns: make(map[net.Conn]struct{}),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -264,20 +277,32 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
|
||||
return
|
||||
}
|
||||
|
||||
backend, ok := ls.lookupRoute(hostname)
|
||||
route, ok := ls.lookupRoute(hostname)
|
||||
if !ok {
|
||||
s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
|
||||
return
|
||||
}
|
||||
|
||||
backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration)
|
||||
// Dispatch based on route mode. L7 will be implemented in a later phase.
|
||||
switch route.Mode {
|
||||
case "l7":
|
||||
s.logger.Error("L7 mode not yet implemented", "hostname", hostname)
|
||||
return
|
||||
default:
|
||||
s.handleL4(ctx, conn, ls, addr, hostname, route, peeked)
|
||||
}
|
||||
}
|
||||
|
||||
// handleL4 handles an L4 (passthrough) connection.
|
||||
func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState, addr netip.Addr, hostname string, route RouteInfo, peeked []byte) {
|
||||
backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration)
|
||||
if err != nil {
|
||||
s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err)
|
||||
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend)
|
||||
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
||||
|
||||
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
|
||||
if err != nil && ctx.Err() == nil {
|
||||
|
||||
@@ -14,6 +14,11 @@ import (
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||
)
|
||||
|
||||
// l4Route creates a RouteInfo for an L4 passthrough route.
|
||||
func l4Route(backend string) RouteInfo {
|
||||
return RouteInfo{Backend: backend, Mode: "l4"}
|
||||
}
|
||||
|
||||
// echoServer accepts one connection, copies everything back, then closes.
|
||||
func echoServer(t *testing.T, ln net.Listener) {
|
||||
t.Helper()
|
||||
@@ -85,8 +90,8 @@ func TestProxyRoundTrip(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{
|
||||
"echo.test": backendLn.Addr().String(),
|
||||
Routes: map[string]RouteInfo{
|
||||
"echo.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -141,8 +146,8 @@ func TestNoRouteResets(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{
|
||||
"other.test": "127.0.0.1:1", // exists but won't match
|
||||
Routes: map[string]RouteInfo{
|
||||
"other.test": l4Route("127.0.0.1:1"), // exists but won't match
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -212,8 +217,8 @@ func TestFirewallBlocks(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{
|
||||
"echo.test": backendLn.Addr().String(),
|
||||
Routes: map[string]RouteInfo{
|
||||
"echo.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
}, logger, "test")
|
||||
@@ -267,7 +272,7 @@ func TestNotTLSResets(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{"x.test": "127.0.0.1:1"},
|
||||
Routes: map[string]RouteInfo{"x.test": l4Route("127.0.0.1:1")},
|
||||
},
|
||||
})
|
||||
|
||||
@@ -325,8 +330,8 @@ func TestConnectionTracking(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{
|
||||
"conn.test": backendLn.Addr().String(),
|
||||
Routes: map[string]RouteInfo{
|
||||
"conn.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -432,8 +437,8 @@ func TestMultipleListeners(t *testing.T) {
|
||||
ln2.Close()
|
||||
|
||||
srv := newTestServer(t, []ListenerData{
|
||||
{ID: 1, Addr: addr1, Routes: map[string]string{"svc.test": backendA.Addr().String()}},
|
||||
{ID: 2, Addr: addr2, Routes: map[string]string{"svc.test": backendB.Addr().String()}},
|
||||
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
|
||||
{ID: 2, Addr: addr2, Routes: map[string]RouteInfo{"svc.test": l4Route(backendB.Addr().String())}},
|
||||
})
|
||||
|
||||
stop := startAndStop(t, srv)
|
||||
@@ -500,8 +505,8 @@ func TestCaseInsensitiveRouting(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{
|
||||
"echo.test": backendLn.Addr().String(),
|
||||
Routes: map[string]RouteInfo{
|
||||
"echo.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -549,8 +554,8 @@ func TestBackendUnreachable(t *testing.T) {
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]string{
|
||||
"dead.test": deadAddr,
|
||||
Routes: map[string]RouteInfo{
|
||||
"dead.test": l4Route(deadAddr),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -612,7 +617,7 @@ func TestGracefulShutdown(t *testing.T) {
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
srv := New(cfg, fw, []ListenerData{
|
||||
{ID: 1, Addr: proxyAddr, Routes: map[string]string{"hold.test": backendLn.Addr().String()}},
|
||||
{ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{"hold.test": l4Route(backendLn.Addr().String())}},
|
||||
}, logger, "test")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -651,18 +656,18 @@ func TestListenerStateRoutes(t *testing.T) {
|
||||
ls := &ListenerState{
|
||||
ID: 1,
|
||||
Addr: ":443",
|
||||
routes: map[string]string{
|
||||
"a.test": "127.0.0.1:1",
|
||||
routes: map[string]RouteInfo{
|
||||
"a.test": l4Route("127.0.0.1:1"),
|
||||
},
|
||||
}
|
||||
|
||||
// AddRoute
|
||||
if err := ls.AddRoute("b.test", "127.0.0.1:2"); err != nil {
|
||||
if err := ls.AddRoute("b.test", l4Route("127.0.0.1:2")); err != nil {
|
||||
t.Fatalf("AddRoute: %v", err)
|
||||
}
|
||||
|
||||
// AddRoute duplicate
|
||||
if err := ls.AddRoute("b.test", "127.0.0.1:3"); err == nil {
|
||||
if err := ls.AddRoute("b.test", l4Route("127.0.0.1:3")); err == nil {
|
||||
t.Fatal("expected error for duplicate route")
|
||||
}
|
||||
|
||||
@@ -686,8 +691,8 @@ func TestListenerStateRoutes(t *testing.T) {
|
||||
if len(routes) != 1 {
|
||||
t.Fatalf("expected 1 route, got %d", len(routes))
|
||||
}
|
||||
if routes["b.test"] != "127.0.0.1:2" {
|
||||
t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"])
|
||||
if routes["b.test"].Backend != "127.0.0.1:2" {
|
||||
t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"].Backend)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user