From ffc31f7d5523944da8e5593625abe63b59bee5a5 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Wed, 25 Mar 2026 18:05:25 -0700 Subject: [PATCH] Add Prometheus metrics for connections, firewall, L7, and bytes transferred Instrument mc-proxy with prometheus/client_golang. New internal/metrics/ package defines counters, gauges, and histograms for connection totals, active connections, firewall blocks by reason, backend dial latency, bytes transferred, L7 HTTP status codes, and L7 policy blocks. Optional [metrics] config section starts a scrape endpoint. Firewall gains BlockedWithReason() to report block cause. L7 handler wraps ResponseWriter to record status codes per hostname. Co-Authored-By: Claude Opus 4.6 (1M context) --- ARCHITECTURE.md | 3 - CLAUDE.md | 1 + PROGRESS.md | 23 +++--- cmd/mc-proxy/server.go | 11 +++ go.mod | 8 ++ go.sum | 14 ++++ internal/config/config.go | 11 +++ internal/config/config_test.go | 51 ++++++++++++ internal/firewall/firewall.go | 17 ++-- internal/firewall/firewall_test.go | 41 ++++++++++ internal/l7/policy.go | 6 +- internal/l7/policy_test.go | 12 +-- internal/l7/serve.go | 26 ++++++- internal/metrics/metrics.go | 95 ++++++++++++++++++++++ internal/metrics/metrics_test.go | 121 +++++++++++++++++++++++++++++ internal/server/server.go | 31 +++++++- 16 files changed, 439 insertions(+), 32 deletions(-) create mode 100644 internal/metrics/metrics.go create mode 100644 internal/metrics/metrics_test.go diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 3ba2b12..b87e199 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -710,9 +710,6 @@ Items are listed roughly in priority order: | Item | Description | |------|-------------| | **ACME integration** | Automatic certificate provisioning via Let's Encrypt for L7 routes, removing the need for manual cert management. | -| **L7 policies** | User-agent blocking, header-based routing, request rate limiting per endpoint. Requires L7 mode. | | **MCP integration** | Wire the gRPC admin API into the Metacircular Control Plane for centralized management. | | **Connection pooling** | Pool backend connections for L7 routes to reduce connection setup overhead under high request volume. | -| **Per-listener connection limits** | Cap maximum concurrent connections per listener. | -| **Metrics** | Prometheus-compatible metrics: connections per listener, firewall blocks by rule, backend dial latency, active connections, HTTP status code distributions. | | **Metacrypt key storage** | Store L7 TLS private keys in metacrypt rather than on the filesystem. | diff --git a/CLAUDE.md b/CLAUDE.md index 5a77737..ae24cde 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -49,6 +49,7 @@ go test ./internal/sni -run TestExtract - `internal/l7/` — L7 TLS termination, `prefixConn`, HTTP/2 reverse proxy with h2c backend transport - `internal/server/` — orchestrates listeners → PROXY protocol → firewall → SNI → route → L4/L7 dispatch; per-listener state with connection tracking - `internal/grpcserver/` — gRPC admin API: route/firewall CRUD, status, write-through to DB +- `internal/metrics/` — Prometheus metric definitions and HTTP server; optional `[metrics]` config section - `proto/mc_proxy/v1/` — protobuf definitions; `gen/mc_proxy/v1/` has generated code ## Signals diff --git a/PROGRESS.md b/PROGRESS.md index dc62595..adf5aa9 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -71,20 +71,21 @@ proceeds. Each item is marked: ## Phase 8: Prometheus Metrics -- [ ] 8.1 Dependency: add `prometheus/client_golang` -- [ ] 8.2 Config: `Metrics` section (`addr`, `path`) -- [ ] 8.3 Package: `internal/metrics/` definitions and HTTP server -- [ ] 8.4 Instrumentation: connections, firewall, dial latency, bytes, HTTP status, policy blocks -- [ ] 8.5 Firewall: `BlockedWithReason()` method -- [ ] 8.6 L7: status recording on ResponseWriter -- [ ] 8.7 Startup: conditionally start metrics server -- [ ] 8.8 Tests: metric sanity, server endpoint, `BlockedWithReason` +- [x] 8.1 Dependency: add `prometheus/client_golang` +- [x] 8.2 Config: `Metrics` section (`addr`, `path`) +- [x] 8.3 Package: `internal/metrics/` definitions and HTTP server +- [x] 8.4 Instrumentation: connections, firewall, dial latency, bytes, HTTP status, policy blocks +- [x] 8.5 Firewall: `BlockedWithReason()` method +- [x] 8.6 L7: status recording on ResponseWriter +- [x] 8.7 Startup: conditionally start metrics server +- [x] 8.8 Tests: metric sanity, server endpoint, `BlockedWithReason` --- ## Current State -Phases 1-6 complete. Per-listener connection limits are implemented and -tested. L7 policies and Prometheus metrics are next. +Phases 1-8 complete. Prometheus metrics are instrumented across +connections, firewall, dial latency, bytes transferred, L7 HTTP status +codes, and L7 policy blocks. -`go vet` and `go test` pass across all 13 packages. +`go vet` and `go test` pass across all 14 packages. diff --git a/cmd/mc-proxy/server.go b/cmd/mc-proxy/server.go index e1dc26a..30d9d2f 100644 --- a/cmd/mc-proxy/server.go +++ b/cmd/mc-proxy/server.go @@ -15,6 +15,7 @@ import ( "git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver" + "git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/kyle/mc-proxy/internal/server" ) @@ -108,6 +109,16 @@ func serverCmd() *cobra.Command { } }() + // Start Prometheus metrics server if configured. + if cfg.Metrics.Addr != "" { + logger.Info("metrics server listening", "addr", cfg.Metrics.Addr, "path", cfg.Metrics.Path) + go func() { + if err := metrics.ListenAndServe(ctx, cfg.Metrics.Addr, cfg.Metrics.Path); err != nil { + logger.Error("metrics server error", "error", err) + } + }() + } + logger.Info("mc-proxy starting", "version", version) return srv.Run(ctx) }, diff --git a/go.mod b/go.mod index 397ad34..62214e6 100644 --- a/go.mod +++ b/go.mod @@ -14,14 +14,22 @@ require ( replace git.wntrmute.dev/kyle/mcdsl => /home/kyle/src/metacircular/mcdsl require ( + github.com/beorn7/perks v1.0.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/pelletier/go-toml/v2 v2.3.0 // indirect + github.com/prometheus/client_golang v1.23.2 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/common v0.66.1 // indirect + github.com/prometheus/procfs v0.16.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/spf13/pflag v1.0.9 // indirect + go.yaml.in/yaml/v2 v2.4.2 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.32.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect diff --git a/go.sum b/go.sum index 4646df4..078d4cc 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= @@ -23,6 +25,8 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= @@ -31,6 +35,14 @@ github.com/pelletier/go-toml/v2 v2.3.0 h1:k59bC/lIZREW0/iVaQR8nDHxVq8OVlIzYCOJf4 github.com/pelletier/go-toml/v2 v2.3.0/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= +github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= +github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -52,6 +64,8 @@ go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2W go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= +go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= golang.org/x/mod v0.33.0/go.mod h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= diff --git a/internal/config/config.go b/internal/config/config.go index fee552b..1739ddb 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -22,9 +22,16 @@ type Config struct { GRPC GRPC `toml:"grpc"` Firewall Firewall `toml:"firewall"` Proxy Proxy `toml:"proxy"` + Metrics Metrics `toml:"metrics"` Log Log `toml:"log"` } +// Metrics holds the Prometheus metrics endpoint configuration. +type Metrics struct { + Addr string `toml:"addr"` // e.g. "127.0.0.1:9090" + Path string `toml:"path"` // e.g. "/metrics" (default) +} + // Database holds the database configuration. type Database struct { Path string `toml:"path"` @@ -215,6 +222,10 @@ func (c *Config) validate() error { } } + if c.Metrics.Addr != "" && c.Metrics.Path != "" && !strings.HasPrefix(c.Metrics.Path, "/") { + return fmt.Errorf("metrics.path must start with \"/\"") + } + if c.Proxy.ConnectTimeout.Duration < 0 { return fmt.Errorf("proxy.connect_timeout must not be negative") } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 1c016c8..177dbdc 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -541,3 +541,54 @@ proxy_protocol = true t.Fatal("expected send_proxy_protocol = true") } } + +func TestLoadMetricsConfig(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" + +[metrics] +addr = "127.0.0.1:9090" +path = "/metrics" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Metrics.Addr != "127.0.0.1:9090" { + t.Fatalf("got metrics.addr %q, want %q", cfg.Metrics.Addr, "127.0.0.1:9090") + } + if cfg.Metrics.Path != "/metrics" { + t.Fatalf("got metrics.path %q, want %q", cfg.Metrics.Path, "/metrics") + } +} + +func TestValidateMetricsInvalidPath(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" + +[metrics] +addr = "127.0.0.1:9090" +path = "no-slash" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for metrics.path without leading slash") + } +} diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 931f8ef..187a9ba 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -71,18 +71,25 @@ func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rate // Blocked returns true if the given address should be blocked. func (f *Firewall) Blocked(addr netip.Addr) bool { + blocked, _ := f.BlockedWithReason(addr) + return blocked +} + +// BlockedWithReason returns whether the address is blocked and the reason. +// Possible reasons: "ip", "cidr", "country", "rate_limit", or "" if not blocked. +func (f *Firewall) BlockedWithReason(addr netip.Addr) (bool, string) { addr = addr.Unmap() f.mu.RLock() defer f.mu.RUnlock() if _, ok := f.blockedIPs[addr]; ok { - return true + return true, "ip" } for _, prefix := range f.blockedCIDRs { if prefix.Contains(addr) { - return true + return true, "cidr" } } @@ -90,7 +97,7 @@ func (f *Firewall) Blocked(addr netip.Addr) bool { var record geoIPRecord if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil { if _, ok := f.blockedCountries[record.Country.ISOCode]; ok { - return true + return true, "country" } } } @@ -98,10 +105,10 @@ func (f *Firewall) Blocked(addr netip.Addr) bool { // Rate limiting is checked after blocklist — no point tracking state // for already-blocked IPs. if f.rl != nil && !f.rl.Allow(addr) { - return true + return true, "rate_limit" } - return false + return false, "" } // AddIP adds an IP address to the blocklist. diff --git a/internal/firewall/firewall_test.go b/internal/firewall/firewall_test.go index 5ef38d7..6a83cdc 100644 --- a/internal/firewall/firewall_test.go +++ b/internal/firewall/firewall_test.go @@ -170,6 +170,47 @@ func TestRateLimitBlocklistFirst(t *testing.T) { } } +func TestBlockedWithReason(t *testing.T) { + fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 2, time.Minute) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + tests := []struct { + addr string + wantBlock bool + wantReason string + }{ + {"10.0.0.1", true, "ip"}, + {"192.168.1.1", true, "cidr"}, + {"172.16.0.1", false, ""}, + } + + for _, tt := range tests { + addr := netip.MustParseAddr(tt.addr) + blocked, reason := fw.BlockedWithReason(addr) + if blocked != tt.wantBlock { + t.Fatalf("BlockedWithReason(%s) blocked = %v, want %v", tt.addr, blocked, tt.wantBlock) + } + if reason != tt.wantReason { + t.Fatalf("BlockedWithReason(%s) reason = %q, want %q", tt.addr, reason, tt.wantReason) + } + } + + // Test rate limit reason: use a fresh IP that will exceed the limit. + rlAddr := netip.MustParseAddr("10.10.10.10") + fw.BlockedWithReason(rlAddr) // 1 + fw.BlockedWithReason(rlAddr) // 2 + blocked, reason := fw.BlockedWithReason(rlAddr) // 3 — should be blocked + if !blocked { + t.Fatal("expected rate limit block") + } + if reason != "rate_limit" { + t.Fatalf("reason = %q, want %q", reason, "rate_limit") + } +} + func TestRuntimeMutation(t *testing.T) { fw, err := New("", nil, nil, nil, 0, 0) if err != nil { diff --git a/internal/l7/policy.go b/internal/l7/policy.go index 932ad1a..fa5996a 100644 --- a/internal/l7/policy.go +++ b/internal/l7/policy.go @@ -3,6 +3,8 @@ package l7 import ( "net/http" "strings" + + "git.wntrmute.dev/kyle/mc-proxy/internal/metrics" ) // PolicyRule defines an L7 blocking policy. @@ -14,7 +16,7 @@ type PolicyRule struct { // PolicyMiddleware returns an http.Handler that evaluates L7 policies // before delegating to next. Returns HTTP 403 if any policy blocks. // If policies is empty, returns next unchanged. -func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler { +func PolicyMiddleware(policies []PolicyRule, hostname string, next http.Handler) http.Handler { if len(policies) == 0 { return next } @@ -23,11 +25,13 @@ func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler { switch p.Type { case "block_user_agent": if strings.Contains(r.UserAgent(), p.Value) { + metrics.L7PolicyBlocksTotal.WithLabelValues(hostname, "block_user_agent").Inc() w.WriteHeader(http.StatusForbidden) return } case "require_header": if r.Header.Get(p.Value) == "" { + metrics.L7PolicyBlocksTotal.WithLabelValues(hostname, "require_header").Inc() w.WriteHeader(http.StatusForbidden) return } diff --git a/internal/l7/policy_test.go b/internal/l7/policy_test.go index decb9ec..3f71c60 100644 --- a/internal/l7/policy_test.go +++ b/internal/l7/policy_test.go @@ -13,7 +13,7 @@ func TestPolicyMiddlewareNoPolicies(t *testing.T) { w.WriteHeader(200) }) - handler := PolicyMiddleware(nil, next) + handler := PolicyMiddleware(nil, "test.example.com", next) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -35,7 +35,7 @@ func TestPolicyBlockUserAgentMatch(t *testing.T) { policies := []PolicyRule{ {Type: "block_user_agent", Value: "BadBot"}, } - handler := PolicyMiddleware(policies, next) + handler := PolicyMiddleware(policies, "test.example.com", next) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("User-Agent", "Mozilla/5.0 BadBot/1.0") @@ -57,7 +57,7 @@ func TestPolicyBlockUserAgentNoMatch(t *testing.T) { policies := []PolicyRule{ {Type: "block_user_agent", Value: "BadBot"}, } - handler := PolicyMiddleware(policies, next) + handler := PolicyMiddleware(policies, "test.example.com", next) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("User-Agent", "Mozilla/5.0 GoodBrowser/1.0") @@ -82,7 +82,7 @@ func TestPolicyRequireHeaderPresent(t *testing.T) { policies := []PolicyRule{ {Type: "require_header", Value: "X-API-Key"}, } - handler := PolicyMiddleware(policies, next) + handler := PolicyMiddleware(policies, "test.example.com", next) req := httptest.NewRequest("GET", "/", nil) req.Header.Set("X-API-Key", "secret") @@ -105,7 +105,7 @@ func TestPolicyRequireHeaderAbsent(t *testing.T) { policies := []PolicyRule{ {Type: "require_header", Value: "X-API-Key"}, } - handler := PolicyMiddleware(policies, next) + handler := PolicyMiddleware(policies, "test.example.com", next) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -125,7 +125,7 @@ func TestPolicyMultipleRules(t *testing.T) { {Type: "block_user_agent", Value: "BadBot"}, {Type: "require_header", Value: "X-Token"}, } - handler := PolicyMiddleware(policies, next) + handler := PolicyMiddleware(policies, "test.example.com", next) // Blocked by UA even though header is present. req := httptest.NewRequest("GET", "/", nil) diff --git a/internal/l7/serve.go b/internal/l7/serve.go index 0d9d759..8121c7b 100644 --- a/internal/l7/serve.go +++ b/internal/l7/serve.go @@ -12,14 +12,17 @@ import ( "net/http/httputil" "net/netip" "net/url" + "strconv" "time" + "git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "golang.org/x/net/http2" ) // RouteConfig holds the L7 route parameters needed by the l7 package. type RouteConfig struct { + Hostname string Backend string TLSCert string TLSKey string @@ -29,6 +32,21 @@ type RouteConfig struct { Policies []PolicyRule } +// statusRecorder wraps http.ResponseWriter to capture the status code. +type statusRecorder struct { + http.ResponseWriter + status int +} + +func (sr *statusRecorder) WriteHeader(code int) { + sr.status = code + sr.ResponseWriter.WriteHeader(code) +} + +func (sr *statusRecorder) Unwrap() http.ResponseWriter { + return sr.ResponseWriter +} + // contextKey is an unexported type for context keys in this package. type contextKey int @@ -75,12 +93,14 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, return fmt.Errorf("creating reverse proxy: %w", err) } - // Build handler chain: context injection → L7 policies → reverse proxy. + // Build handler chain: context injection → metrics → L7 policies → reverse proxy. var inner http.Handler = rp - inner = PolicyMiddleware(route.Policies, inner) + inner = PolicyMiddleware(route.Policies, route.Hostname, inner) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr)) - inner.ServeHTTP(w, r) + sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + inner.ServeHTTP(sr, r) + metrics.L7ResponsesTotal.WithLabelValues(route.Hostname, strconv.Itoa(sr.status)).Inc() }) // Serve HTTP on the TLS connection. Use HTTP/2 if negotiated, diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..d174767 --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,95 @@ +// Package metrics defines Prometheus metrics for mc-proxy and provides +// an HTTP server for the /metrics endpoint. +package metrics + +import ( + "context" + "errors" + "net" + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +var ( + // ConnectionsTotal counts connections accepted per listener and mode. + ConnectionsTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "mcproxy", + Name: "connections_total", + Help: "Total connections accepted.", + }, []string{"listener", "mode"}) + + // ConnectionsActive tracks currently active connections per listener. + ConnectionsActive = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: "mcproxy", + Name: "connections_active", + Help: "Currently active connections.", + }, []string{"listener"}) + + // FirewallBlockedTotal counts firewall blocks by reason. + FirewallBlockedTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "mcproxy", + Name: "firewall_blocked_total", + Help: "Total connections blocked by the firewall.", + }, []string{"reason"}) + + // BackendDialDuration observes backend dial latency in seconds. + BackendDialDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: "mcproxy", + Name: "backend_dial_duration_seconds", + Help: "Backend dial latency in seconds.", + Buckets: []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5}, + }, []string{"backend"}) + + // TransferredBytesTotal counts bytes transferred by direction and hostname. + TransferredBytesTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "mcproxy", + Name: "transferred_bytes_total", + Help: "Total bytes transferred.", + }, []string{"direction", "hostname"}) + + // L7ResponsesTotal counts L7 HTTP responses by hostname and status code. + L7ResponsesTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "mcproxy", + Name: "l7_responses_total", + Help: "Total L7 HTTP responses.", + }, []string{"hostname", "code"}) + + // L7PolicyBlocksTotal counts L7 policy blocks by hostname and policy type. + L7PolicyBlocksTotal = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "mcproxy", + Name: "l7_policy_blocks_total", + Help: "Total L7 policy blocks.", + }, []string{"hostname", "policy_type"}) +) + +// ListenAndServe starts a Prometheus metrics HTTP server. It blocks until +// ctx is cancelled, then shuts down gracefully. +func ListenAndServe(ctx context.Context, addr, path string) error { + if path == "" { + path = "/metrics" + } + + mux := http.NewServeMux() + mux.Handle(path, promhttp.Handler()) + + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + + srv := &http.Server{Handler: mux} + + go func() { + <-ctx.Done() + _ = srv.Close() + }() + + err = srv.Serve(ln) + if errors.Is(err, http.ErrServerClosed) { + return nil + } + return err +} diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go new file mode 100644 index 0000000..9a2b050 --- /dev/null +++ b/internal/metrics/metrics_test.go @@ -0,0 +1,121 @@ +package metrics + +import ( + "context" + "io" + "net" + "net/http" + "strings" + "testing" + "time" +) + +func TestListenAndServeShutdown(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 1) + go func() { + errCh <- ListenAndServe(ctx, "127.0.0.1:0", "/metrics") + }() + + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("ListenAndServe returned error: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("ListenAndServe did not return after context cancel") + } +} + +func TestMetricsEndpoint(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := ln.Addr().String() + _ = ln.Close() + + // Increment counters so they appear in output. + ConnectionsTotal.WithLabelValues("127.0.0.1:4430", "l4").Inc() + FirewallBlockedTotal.WithLabelValues("ip").Inc() + ConnectionsActive.WithLabelValues("127.0.0.1:4430").Set(1) + + go func() { _ = ListenAndServe(ctx, addr, "/metrics") }() + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get("http://" + addr + "/metrics") + if err != nil { + t.Fatalf("GET /metrics: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + + text := string(body) + for _, want := range []string{ + "mcproxy_connections_total", + "mcproxy_firewall_blocked_total", + "mcproxy_connections_active", + } { + if !strings.Contains(text, want) { + t.Errorf("response missing %s", want) + } + } +} + +func TestMetricsDefaultPath(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + addr := ln.Addr().String() + _ = ln.Close() + + go func() { _ = ListenAndServe(ctx, addr, "") }() + time.Sleep(100 * time.Millisecond) + + resp, err := http.Get("http://" + addr + "/metrics") + if err != nil { + t.Fatalf("GET /metrics: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } +} + +func TestMetricsSanity(t *testing.T) { + // Verify all metric vars can be used without panicking. + ConnectionsTotal.WithLabelValues("test:443", "l4").Inc() + ConnectionsActive.WithLabelValues("test:443").Set(5) + FirewallBlockedTotal.WithLabelValues("ip").Inc() + FirewallBlockedTotal.WithLabelValues("cidr").Inc() + FirewallBlockedTotal.WithLabelValues("country").Inc() + FirewallBlockedTotal.WithLabelValues("rate_limit").Inc() + BackendDialDuration.WithLabelValues("127.0.0.1:8080").Observe(0.005) + TransferredBytesTotal.WithLabelValues("client_to_backend", "example.com").Add(1024) + TransferredBytesTotal.WithLabelValues("backend_to_client", "example.com").Add(2048) + L7ResponsesTotal.WithLabelValues("example.com", "200").Inc() + L7ResponsesTotal.WithLabelValues("example.com", "502").Inc() + L7PolicyBlocksTotal.WithLabelValues("example.com", "block_user_agent").Inc() + L7PolicyBlocksTotal.WithLabelValues("example.com", "require_header").Inc() +} diff --git a/internal/server/server.go b/internal/server/server.go index b32e557..8eb9c44 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/l7" + "git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/kyle/mc-proxy/internal/proxy" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/kyle/mc-proxy/internal/sni" @@ -41,7 +42,7 @@ type ListenerState struct { ID int64 // database primary key Addr string ProxyProtocol bool - MaxConnections int64 // 0 = unlimited + MaxConnections int64 // 0 = unlimited routes map[string]RouteInfo // lowercase hostname → route info mu sync.RWMutex ActiveConnections atomic.Int64 @@ -204,6 +205,17 @@ func (s *Server) Version() string { return s.version } +// listenerAddrForRoute finds the listener address that owns the given hostname. +func (s *Server) listenerAddrForRoute(hostname string) string { + key := strings.ToLower(hostname) + for _, ls := range s.listeners { + if _, ok := ls.lookupRoute(key); ok { + return ls.Addr + } + } + return "unknown" +} + // TotalConnections returns the total number of active connections. func (s *Server) TotalConnections() int64 { var total int64 @@ -289,6 +301,7 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) s.wg.Add(1) ls.ActiveConnections.Add(1) + metrics.ConnectionsActive.WithLabelValues(ls.Addr).Inc() go s.handleConn(ctx, conn, ls) } } @@ -307,6 +320,7 @@ func (s *Server) forceCloseAll() { func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) { defer s.wg.Done() defer ls.ActiveConnections.Add(-1) + defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec() defer conn.Close() ls.connMu.Lock() @@ -340,8 +354,9 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat } } - if s.fw.Blocked(addr) { - s.logger.Debug("blocked by firewall", "addr", addr) + if blocked, reason := s.fw.BlockedWithReason(addr); blocked { + metrics.FirewallBlockedTotal.WithLabelValues(reason).Inc() + s.logger.Debug("blocked by firewall", "addr", addr, "reason", reason) return } @@ -368,7 +383,11 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat // handleL4 handles an L4 (passthrough) connection. func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) { + metrics.ConnectionsTotal.WithLabelValues(s.listenerAddrForRoute(hostname), "l4").Inc() + + dialStart := time.Now() backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration) + metrics.BackendDialDuration.WithLabelValues(route.Backend).Observe(time.Since(dialStart).Seconds()) if err != nil { s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err) return @@ -391,6 +410,9 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c s.logger.Debug("relay ended", "hostname", hostname, "error", err) } + metrics.TransferredBytesTotal.WithLabelValues("client_to_backend", hostname).Add(float64(result.ClientBytes)) + metrics.TransferredBytesTotal.WithLabelValues("backend_to_client", hostname).Add(float64(result.BackendBytes)) + s.logger.Info("connection closed", "addr", addr, "hostname", hostname, @@ -401,6 +423,8 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c // handleL7 handles an L7 (TLS-terminating) connection. func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) { + metrics.ConnectionsTotal.WithLabelValues(s.listenerAddrForRoute(hostname), "l7").Inc() + s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend) var policies []l7.PolicyRule @@ -409,6 +433,7 @@ func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, c } rc := l7.RouteConfig{ + Hostname: hostname, Backend: route.Backend, TLSCert: route.TLSCert, TLSKey: route.TLSKey,