Add Prometheus metrics for connections, firewall, L7, and bytes transferred

Instrument mc-proxy with prometheus/client_golang. New internal/metrics/
package defines counters, gauges, and histograms for connection totals,
active connections, firewall blocks by reason, backend dial latency,
bytes transferred, L7 HTTP status codes, and L7 policy blocks. Optional
[metrics] config section starts a scrape endpoint. Firewall gains
BlockedWithReason() to report block cause. L7 handler wraps
ResponseWriter to record status codes per hostname.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 18:05:25 -07:00
parent 42c7fffc3e
commit ffc31f7d55
16 changed files with 439 additions and 32 deletions

View File

@@ -22,9 +22,16 @@ type Config struct {
GRPC GRPC `toml:"grpc"`
Firewall Firewall `toml:"firewall"`
Proxy Proxy `toml:"proxy"`
Metrics Metrics `toml:"metrics"`
Log Log `toml:"log"`
}
// Metrics describes the optional HTTP endpoint that exposes Prometheus
// metrics.
type Metrics struct {
	// Addr is the listen address of the metrics endpoint,
	// e.g. "127.0.0.1:9090".
	Addr string `toml:"addr"`

	// Path is the URL path the metrics are served on,
	// e.g. "/metrics" (default).
	Path string `toml:"path"`
}
// Database holds the database configuration.
type Database struct {
Path string `toml:"path"`
@@ -215,6 +222,10 @@ func (c *Config) validate() error {
}
}
if c.Metrics.Addr != "" && c.Metrics.Path != "" && !strings.HasPrefix(c.Metrics.Path, "/") {
return fmt.Errorf("metrics.path must start with \"/\"")
}
if c.Proxy.ConnectTimeout.Duration < 0 {
return fmt.Errorf("proxy.connect_timeout must not be negative")
}

View File

@@ -541,3 +541,54 @@ proxy_protocol = true
t.Fatal("expected send_proxy_protocol = true")
}
}
// TestLoadMetricsConfig checks that a [metrics] section in the config file is
// loaded into Config.Metrics with its addr and path values intact.
func TestLoadMetricsConfig(t *testing.T) {
	const (
		wantAddr = "127.0.0.1:9090"
		wantPath = "/metrics"
	)

	cfgPath := filepath.Join(t.TempDir(), "test.toml")
	contents := []byte(`
[database]
path = "/tmp/test.db"
[metrics]
addr = "127.0.0.1:9090"
path = "/metrics"
`)
	if writeErr := os.WriteFile(cfgPath, contents, 0600); writeErr != nil {
		t.Fatalf("write config: %v", writeErr)
	}

	cfg, err := Load(cfgPath)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	if got := cfg.Metrics.Addr; got != wantAddr {
		t.Fatalf("got metrics.addr %q, want %q", got, wantAddr)
	}
	if got := cfg.Metrics.Path; got != wantPath {
		t.Fatalf("got metrics.path %q, want %q", got, wantPath)
	}
}
// TestValidateMetricsInvalidPath checks that loading a config whose
// metrics.path lacks a leading slash (while metrics.addr is set) fails.
func TestValidateMetricsInvalidPath(t *testing.T) {
	const contents = `
[database]
path = "/tmp/test.db"
[metrics]
addr = "127.0.0.1:9090"
path = "no-slash"
`
	cfgPath := filepath.Join(t.TempDir(), "test.toml")
	if writeErr := os.WriteFile(cfgPath, []byte(contents), 0600); writeErr != nil {
		t.Fatalf("write config: %v", writeErr)
	}

	if _, err := Load(cfgPath); err == nil {
		t.Fatal("expected error for metrics.path without leading slash")
	}
}

View File

@@ -71,18 +71,25 @@ func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rate
// Blocked reports whether the given address should be blocked. It is a
// convenience wrapper around BlockedWithReason that discards the reason.
func (f *Firewall) Blocked(addr netip.Addr) bool {
	isBlocked, _ := f.BlockedWithReason(addr)
	return isBlocked
}
// BlockedWithReason returns whether the address is blocked and the reason.
// Possible reasons: "ip", "cidr", "country", "rate_limit", or "" if not blocked.
func (f *Firewall) BlockedWithReason(addr netip.Addr) (bool, string) {
addr = addr.Unmap()
f.mu.RLock()
defer f.mu.RUnlock()
if _, ok := f.blockedIPs[addr]; ok {
return true
return true, "ip"
}
for _, prefix := range f.blockedCIDRs {
if prefix.Contains(addr) {
return true
return true, "cidr"
}
}
@@ -90,7 +97,7 @@ func (f *Firewall) Blocked(addr netip.Addr) bool {
var record geoIPRecord
if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil {
if _, ok := f.blockedCountries[record.Country.ISOCode]; ok {
return true
return true, "country"
}
}
}
@@ -98,10 +105,10 @@ func (f *Firewall) Blocked(addr netip.Addr) bool {
// Rate limiting is checked after blocklist — no point tracking state
// for already-blocked IPs.
if f.rl != nil && !f.rl.Allow(addr) {
return true
return true, "rate_limit"
}
return false
return false, ""
}
// AddIP adds an IP address to the blocklist.

View File

@@ -170,6 +170,47 @@ func TestRateLimitBlocklistFirst(t *testing.T) {
}
}
// TestBlockedWithReason checks the verdict and reason reported by
// BlockedWithReason for an exact-IP match, a CIDR match, an allowed address,
// and an address that exceeds the rate limit configured in New.
func TestBlockedWithReason(t *testing.T) {
	fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 2, time.Minute)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer fw.Close()

	type testCase struct {
		addr       string
		wantBlock  bool
		wantReason string
	}
	cases := []testCase{
		{addr: "10.0.0.1", wantBlock: true, wantReason: "ip"},
		{addr: "192.168.1.1", wantBlock: true, wantReason: "cidr"},
		{addr: "172.16.0.1", wantBlock: false, wantReason: ""},
	}
	for _, tc := range cases {
		gotBlock, gotReason := fw.BlockedWithReason(netip.MustParseAddr(tc.addr))
		if gotBlock != tc.wantBlock {
			t.Fatalf("BlockedWithReason(%s) blocked = %v, want %v", tc.addr, gotBlock, tc.wantBlock)
		}
		if gotReason != tc.wantReason {
			t.Fatalf("BlockedWithReason(%s) reason = %q, want %q", tc.addr, gotReason, tc.wantReason)
		}
	}

	// Rate-limit reason: an address not on any blocklist is checked twice
	// (the limit passed to New), so the third check is expected to be blocked.
	limited := netip.MustParseAddr("10.10.10.10")
	for i := 0; i < 2; i++ {
		fw.BlockedWithReason(limited)
	}
	gotBlock, gotReason := fw.BlockedWithReason(limited)
	if !gotBlock {
		t.Fatal("expected rate limit block")
	}
	if gotReason != "rate_limit" {
		t.Fatalf("reason = %q, want %q", gotReason, "rate_limit")
	}
}
func TestRuntimeMutation(t *testing.T) {
fw, err := New("", nil, nil, nil, 0, 0)
if err != nil {

View File

@@ -3,6 +3,8 @@ package l7
import (
"net/http"
"strings"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
)
// PolicyRule defines an L7 blocking policy.
@@ -14,7 +16,7 @@ type PolicyRule struct {
// PolicyMiddleware returns an http.Handler that evaluates L7 policies
// before delegating to next. Returns HTTP 403 if any policy blocks.
// If policies is empty, returns next unchanged.
func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler {
func PolicyMiddleware(policies []PolicyRule, hostname string, next http.Handler) http.Handler {
if len(policies) == 0 {
return next
}
@@ -23,11 +25,13 @@ func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler {
switch p.Type {
case "block_user_agent":
if strings.Contains(r.UserAgent(), p.Value) {
metrics.L7PolicyBlocksTotal.WithLabelValues(hostname, "block_user_agent").Inc()
w.WriteHeader(http.StatusForbidden)
return
}
case "require_header":
if r.Header.Get(p.Value) == "" {
metrics.L7PolicyBlocksTotal.WithLabelValues(hostname, "require_header").Inc()
w.WriteHeader(http.StatusForbidden)
return
}

View File

@@ -13,7 +13,7 @@ func TestPolicyMiddlewareNoPolicies(t *testing.T) {
w.WriteHeader(200)
})
handler := PolicyMiddleware(nil, next)
handler := PolicyMiddleware(nil, "test.example.com", next)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@@ -35,7 +35,7 @@ func TestPolicyBlockUserAgentMatch(t *testing.T) {
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
}
handler := PolicyMiddleware(policies, next)
handler := PolicyMiddleware(policies, "test.example.com", next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 BadBot/1.0")
@@ -57,7 +57,7 @@ func TestPolicyBlockUserAgentNoMatch(t *testing.T) {
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
}
handler := PolicyMiddleware(policies, next)
handler := PolicyMiddleware(policies, "test.example.com", next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 GoodBrowser/1.0")
@@ -82,7 +82,7 @@ func TestPolicyRequireHeaderPresent(t *testing.T) {
policies := []PolicyRule{
{Type: "require_header", Value: "X-API-Key"},
}
handler := PolicyMiddleware(policies, next)
handler := PolicyMiddleware(policies, "test.example.com", next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-API-Key", "secret")
@@ -105,7 +105,7 @@ func TestPolicyRequireHeaderAbsent(t *testing.T) {
policies := []PolicyRule{
{Type: "require_header", Value: "X-API-Key"},
}
handler := PolicyMiddleware(policies, next)
handler := PolicyMiddleware(policies, "test.example.com", next)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@@ -125,7 +125,7 @@ func TestPolicyMultipleRules(t *testing.T) {
{Type: "block_user_agent", Value: "BadBot"},
{Type: "require_header", Value: "X-Token"},
}
handler := PolicyMiddleware(policies, next)
handler := PolicyMiddleware(policies, "test.example.com", next)
// Blocked by UA even though header is present.
req := httptest.NewRequest("GET", "/", nil)

View File

@@ -12,14 +12,17 @@ import (
"net/http/httputil"
"net/netip"
"net/url"
"strconv"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2"
)
// RouteConfig holds the L7 route parameters needed by the l7 package.
type RouteConfig struct {
Hostname string
Backend string
TLSCert string
TLSKey string
@@ -29,6 +32,21 @@ type RouteConfig struct {
Policies []PolicyRule
}
// statusRecorder wraps http.ResponseWriter to capture the status code of the
// response so it can be reported after the handler returns.
//
// Only the first final status is recorded. The caller seeds status with the
// value to report when the handler never calls WriteHeader explicitly.
type statusRecorder struct {
	http.ResponseWriter
	status int
	// wroteHeader is set once a final header has been written, either
	// explicitly through WriteHeader or implicitly by the first Write.
	wroteHeader bool
}

// WriteHeader records code as the response status and forwards the call to
// the wrapped writer.
//
// Informational 1xx codes (other than 101 Switching Protocols) may precede
// the final header, so they are forwarded but not recorded. Calls made after
// a final header has already been written are forwarded unchanged but do not
// overwrite the recorded status, because net/http ignores such superfluous
// calls and the client never sees their code.
func (sr *statusRecorder) WriteHeader(code int) {
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !informational && !sr.wroteHeader {
		sr.status = code
		sr.wroteHeader = true
	}
	sr.ResponseWriter.WriteHeader(code)
}

// Write forwards b to the wrapped writer. A Write before any WriteHeader
// sends the header implicitly, so it marks the header as written and keeps
// the seeded status from being overwritten by a later WriteHeader call.
func (sr *statusRecorder) Write(b []byte) (int, error) {
	sr.wroteHeader = true
	return sr.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped http.ResponseWriter so that
// http.ResponseController can reach the underlying writer's optional
// interfaces.
func (sr *statusRecorder) Unwrap() http.ResponseWriter {
	return sr.ResponseWriter
}
// contextKey is an unexported type for context keys in this package.
// Using a dedicated unexported type keeps values stored by this package
// from colliding with context keys defined in other packages.
type contextKey int
@@ -75,12 +93,14 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
return fmt.Errorf("creating reverse proxy: %w", err)
}
// Build handler chain: context injection → L7 policies → reverse proxy.
// Build handler chain: context injection → metrics → L7 policies → reverse proxy.
var inner http.Handler = rp
inner = PolicyMiddleware(route.Policies, inner)
inner = PolicyMiddleware(route.Policies, route.Hostname, inner)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
inner.ServeHTTP(w, r)
sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
inner.ServeHTTP(sr, r)
metrics.L7ResponsesTotal.WithLabelValues(route.Hostname, strconv.Itoa(sr.status)).Inc()
})
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,

View File

@@ -0,0 +1,95 @@
// Package metrics defines Prometheus metrics for mc-proxy and provides
// an HTTP server for the /metrics endpoint.
package metrics
import (
	"context"
	"errors"
	"net"
	"net/http"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)
// namespace is the prefix shared by every metric defined in this package.
const namespace = "mcproxy"

// ConnectionsTotal counts accepted connections, labelled by listener address
// and proxy mode.
var ConnectionsTotal = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: namespace,
		Name:      "connections_total",
		Help:      "Total connections accepted.",
	},
	[]string{"listener", "mode"},
)

// ConnectionsActive is the number of connections currently open, labelled by
// listener address.
var ConnectionsActive = promauto.NewGaugeVec(
	prometheus.GaugeOpts{
		Namespace: namespace,
		Name:      "connections_active",
		Help:      "Currently active connections.",
	},
	[]string{"listener"},
)

// FirewallBlockedTotal counts connections rejected by the firewall, labelled
// by the reason for the block.
var FirewallBlockedTotal = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: namespace,
		Name:      "firewall_blocked_total",
		Help:      "Total connections blocked by the firewall.",
	},
	[]string{"reason"},
)

// BackendDialDuration is a histogram of backend dial latency in seconds,
// labelled by backend address. Buckets range from 1ms to 5s.
var BackendDialDuration = promauto.NewHistogramVec(
	prometheus.HistogramOpts{
		Namespace: namespace,
		Name:      "backend_dial_duration_seconds",
		Help:      "Backend dial latency in seconds.",
		Buckets:   []float64{0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5},
	},
	[]string{"backend"},
)

// TransferredBytesTotal counts relayed bytes, labelled by transfer direction
// and route hostname.
var TransferredBytesTotal = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: namespace,
		Name:      "transferred_bytes_total",
		Help:      "Total bytes transferred.",
	},
	[]string{"direction", "hostname"},
)

// L7ResponsesTotal counts L7 HTTP responses, labelled by route hostname and
// HTTP status code.
var L7ResponsesTotal = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: namespace,
		Name:      "l7_responses_total",
		Help:      "Total L7 HTTP responses.",
	},
	[]string{"hostname", "code"},
)

// L7PolicyBlocksTotal counts requests rejected by an L7 policy, labelled by
// route hostname and the type of the policy that matched.
var L7PolicyBlocksTotal = promauto.NewCounterVec(
	prometheus.CounterOpts{
		Namespace: namespace,
		Name:      "l7_policy_blocks_total",
		Help:      "Total L7 policy blocks.",
	},
	[]string{"hostname", "policy_type"},
)
// ListenAndServe starts a Prometheus metrics HTTP server on addr, serving the
// metrics handler at path ("/metrics" when path is empty). It blocks until
// ctx is cancelled, then shuts the server down gracefully.
//
// On cancellation, in-flight scrapes get a bounded amount of time to finish;
// if they do not, the remaining connections are closed forcibly.
//
// It returns nil after a shutdown triggered by ctx, the listen error if addr
// cannot be bound, or the error that made the server stop serving.
func ListenAndServe(ctx context.Context, addr, path string) error {
	const (
		// readHeaderTimeout bounds how long a client may take to send its
		// request headers, so idle or slow clients cannot hold connections
		// open indefinitely.
		readHeaderTimeout = 5 * time.Second
		// shutdownTimeout bounds how long a graceful shutdown may wait for
		// in-flight requests.
		shutdownTimeout = 5 * time.Second
	)

	if path == "" {
		path = "/metrics"
	}

	mux := http.NewServeMux()
	mux.Handle(path, promhttp.Handler())

	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}

	srv := &http.Server{
		Handler:           mux,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	// serveDone tells the watcher goroutine that Serve returned on its own,
	// so the goroutine does not stay parked on ctx after a serve failure.
	// watcherDone lets this function wait for the shutdown to complete.
	serveDone := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case <-serveDone:
		case <-ctx.Done():
			// ctx is already cancelled, so the shutdown deadline must be
			// derived from a fresh context.
			shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
			defer cancel()
			if shutdownErr := srv.Shutdown(shutdownCtx); shutdownErr != nil {
				// Graceful shutdown timed out; drop what is left.
				_ = srv.Close()
			}
		}
	}()

	err = srv.Serve(ln)
	close(serveDone)
	<-watcherDone

	if errors.Is(err, http.ErrServerClosed) {
		return nil
	}
	return err
}

View File

@@ -0,0 +1,121 @@
package metrics
import (
"context"
"io"
"net"
"net/http"
"strings"
"testing"
"time"
)
// TestListenAndServeShutdown checks that ListenAndServe returns nil shortly
// after its context is cancelled.
func TestListenAndServeShutdown(t *testing.T) {
	ctx, stop := context.WithCancel(context.Background())
	defer stop()

	result := make(chan error, 1)
	go func() {
		result <- ListenAndServe(ctx, "127.0.0.1:0", "/metrics")
	}()

	// Give the server a moment to start before requesting shutdown.
	time.Sleep(50 * time.Millisecond)
	stop()

	select {
	case <-time.After(2 * time.Second):
		t.Fatal("ListenAndServe did not return after context cancel")
	case serveErr := <-result:
		if serveErr != nil {
			t.Fatalf("ListenAndServe returned error: %v", serveErr)
		}
	}
}
// TestMetricsEndpoint starts the metrics server on a free local port and
// checks that a scrape of /metrics returns 200 and contains the metric
// families that were touched beforehand.
func TestMetricsEndpoint(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Reserve a free port, then release it so ListenAndServe can bind it.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	addr := ln.Addr().String()
	_ = ln.Close()

	// Increment counters so they appear in output.
	ConnectionsTotal.WithLabelValues("127.0.0.1:4430", "l4").Inc()
	FirewallBlockedTotal.WithLabelValues("ip").Inc()
	ConnectionsActive.WithLabelValues("127.0.0.1:4430").Set(1)

	srvErr := make(chan error, 1)
	go func() { srvErr <- ListenAndServe(ctx, addr, "/metrics") }()

	// Poll until the server answers instead of sleeping for a fixed time:
	// a fixed sleep is flaky on slow machines and hides bind failures.
	var resp *http.Response
	deadline := time.Now().Add(2 * time.Second)
	for {
		resp, err = http.Get("http://" + addr + "/metrics")
		if err == nil {
			break
		}
		select {
		case exitErr := <-srvErr:
			t.Fatalf("ListenAndServe exited early: %v", exitErr)
		default:
		}
		if time.Now().After(deadline) {
			t.Fatalf("GET /metrics: %v", err)
		}
		time.Sleep(10 * time.Millisecond)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != 200 {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	text := string(body)

	for _, want := range []string{
		"mcproxy_connections_total",
		"mcproxy_firewall_blocked_total",
		"mcproxy_connections_active",
	} {
		if !strings.Contains(text, want) {
			t.Errorf("response missing %s", want)
		}
	}
}
// TestMetricsDefaultPath checks that an empty path makes ListenAndServe serve
// the metrics at "/metrics".
func TestMetricsDefaultPath(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Reserve a free port, then release it so ListenAndServe can bind it.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	addr := ln.Addr().String()
	_ = ln.Close()

	srvErr := make(chan error, 1)
	go func() { srvErr <- ListenAndServe(ctx, addr, "") }()

	// Poll until the server answers instead of sleeping for a fixed time:
	// a fixed sleep is flaky on slow machines and hides bind failures.
	var resp *http.Response
	deadline := time.Now().Add(2 * time.Second)
	for {
		resp, err = http.Get("http://" + addr + "/metrics")
		if err == nil {
			break
		}
		select {
		case exitErr := <-srvErr:
			t.Fatalf("ListenAndServe exited early: %v", exitErr)
		default:
		}
		if time.Now().After(deadline) {
			t.Fatalf("GET /metrics: %v", err)
		}
		time.Sleep(10 * time.Millisecond)
	}
	defer func() { _ = resp.Body.Close() }()

	if resp.StatusCode != 200 {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}
}
// TestMetricsSanity exercises every metric variable with a valid label set to
// make sure none of them panics when used.
func TestMetricsSanity(t *testing.T) {
	const (
		listener = "test:443"
		host     = "example.com"
	)

	ConnectionsTotal.WithLabelValues(listener, "l4").Inc()
	ConnectionsActive.WithLabelValues(listener).Set(5)

	for _, reason := range []string{"ip", "cidr", "country", "rate_limit"} {
		FirewallBlockedTotal.WithLabelValues(reason).Inc()
	}

	BackendDialDuration.WithLabelValues("127.0.0.1:8080").Observe(0.005)

	TransferredBytesTotal.WithLabelValues("client_to_backend", host).Add(1024)
	TransferredBytesTotal.WithLabelValues("backend_to_client", host).Add(2048)

	for _, code := range []string{"200", "502"} {
		L7ResponsesTotal.WithLabelValues(host, code).Inc()
	}
	for _, policy := range []string{"block_user_agent", "require_header"} {
		L7PolicyBlocksTotal.WithLabelValues(host, policy).Inc()
	}
}

View File

@@ -14,6 +14,7 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/l7"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
@@ -41,7 +42,7 @@ type ListenerState struct {
ID int64 // database primary key
Addr string
ProxyProtocol bool
MaxConnections int64 // 0 = unlimited
MaxConnections int64 // 0 = unlimited
routes map[string]RouteInfo // lowercase hostname → route info
mu sync.RWMutex
ActiveConnections atomic.Int64
@@ -204,6 +205,17 @@ func (s *Server) Version() string {
return s.version
}
// listenerAddrForRoute returns the address of the first listener that has a
// route for hostname (matched case-insensitively), or "unknown" when no
// listener has one.
//
// NOTE(review): when the same hostname is routed on more than one listener,
// the first match wins, which may not be the listener that accepted the
// connection — confirm this is acceptable for the "listener" metric label.
func (s *Server) listenerAddrForRoute(hostname string) string {
	const fallback = "unknown"

	needle := strings.ToLower(hostname)
	for _, candidate := range s.listeners {
		_, found := candidate.lookupRoute(needle)
		if !found {
			continue
		}
		return candidate.Addr
	}
	return fallback
}
// TotalConnections returns the total number of active connections.
func (s *Server) TotalConnections() int64 {
var total int64
@@ -289,6 +301,7 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
s.wg.Add(1)
ls.ActiveConnections.Add(1)
metrics.ConnectionsActive.WithLabelValues(ls.Addr).Inc()
go s.handleConn(ctx, conn, ls)
}
}
@@ -307,6 +320,7 @@ func (s *Server) forceCloseAll() {
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
defer s.wg.Done()
defer ls.ActiveConnections.Add(-1)
defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
defer conn.Close()
ls.connMu.Lock()
@@ -340,8 +354,9 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
}
}
if s.fw.Blocked(addr) {
s.logger.Debug("blocked by firewall", "addr", addr)
if blocked, reason := s.fw.BlockedWithReason(addr); blocked {
metrics.FirewallBlockedTotal.WithLabelValues(reason).Inc()
s.logger.Debug("blocked by firewall", "addr", addr, "reason", reason)
return
}
@@ -368,7 +383,11 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
// handleL4 handles an L4 (passthrough) connection.
func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
metrics.ConnectionsTotal.WithLabelValues(s.listenerAddrForRoute(hostname), "l4").Inc()
dialStart := time.Now()
backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration)
metrics.BackendDialDuration.WithLabelValues(route.Backend).Observe(time.Since(dialStart).Seconds())
if err != nil {
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
return
@@ -391,6 +410,9 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
s.logger.Debug("relay ended", "hostname", hostname, "error", err)
}
metrics.TransferredBytesTotal.WithLabelValues("client_to_backend", hostname).Add(float64(result.ClientBytes))
metrics.TransferredBytesTotal.WithLabelValues("backend_to_client", hostname).Add(float64(result.BackendBytes))
s.logger.Info("connection closed",
"addr", addr,
"hostname", hostname,
@@ -401,6 +423,8 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
// handleL7 handles an L7 (TLS-terminating) connection.
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
metrics.ConnectionsTotal.WithLabelValues(s.listenerAddrForRoute(hostname), "l7").Inc()
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
var policies []l7.PolicyRule
@@ -409,6 +433,7 @@ func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, c
}
rc := l7.RouteConfig{
Hostname: hostname,
Backend: route.Backend,
TLSCert: route.TLSCert,
TLSKey: route.TLSKey,