Per-route HTTP-level blocking policies for L7 routes. Two rule types: block_user_agent (substring match against User-Agent, returns 403) and require_header (named header must be present, returns 403). Config: L7Policy struct with type/value fields, added as L7Policies slice on Route. Validated in config (type enum, non-empty value, warning if set on L4 routes). DB: Migration 4 creates l7_policies table with route_id FK (cascade delete), type CHECK constraint, UNIQUE(route_id, type, value). New l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy, GetRouteID. Seed updated to persist policies from config. L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates rules in order, returns 403 on first match, no-op if empty. Composed into the handler chain between context injection and reverse proxy. Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy mutation methods on ListenerState. handleL7 threads policies into l7.RouteConfig. Startup loads policies per L7 route from DB. Proto: L7Policy message, repeated l7_policies on Route. Three new RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the write-through pattern. Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy methods. CLI: mcproxyctl policies list/add/remove subcommands. Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match, header present/absent, multiple rules). 4 DB tests (CRUD, cascade, duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation). 2 end-to-end L7 tests (UA block, required header with allow/deny). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
666 lines
17 KiB
Go
666 lines
17 KiB
Go
package l7
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ecdsa"
|
|
"crypto/elliptic"
|
|
"crypto/rand"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"log/slog"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/net/http2"
|
|
"golang.org/x/net/http2/h2c"
|
|
)
|
|
|
|
// testCert generates a self-signed ECDSA P-256 certificate for hostname,
// valid from one hour ago until one hour from now, and writes the PEM-encoded
// certificate and private key into a per-test temporary directory. It returns
// the paths of the two files. Any failure aborts the test via t.Fatalf.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generating key: %v", err)
	}

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: hostname},
		DNSNames:     []string{hostname},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	// Self-signed: the template is passed as both the certificate and its parent.
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("creating certificate: %v", err)
	}

	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("marshaling key: %v", err)
	}

	dir := t.TempDir()
	certPath = filepath.Join(dir, "cert.pem")
	keyPath = filepath.Join(dir, "key.pem")

	// os.WriteFile reports both write and close errors, so a truncated PEM
	// file fails here instead of later inside the TLS setup.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		t.Fatalf("writing cert file: %v", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		t.Fatalf("writing key file: %v", err)
	}

	return certPath, keyPath
}
|
|
|
|
// startH2CBackend starts an h2c (HTTP/2 cleartext) backend server that
|
|
// responds with the given status and body. Returns the listener address.
|
|
func startH2CBackend(t *testing.T, handler http.Handler) string {
|
|
t.Helper()
|
|
|
|
h2s := &http2.Server{}
|
|
srv := &http.Server{
|
|
Handler: h2c.NewHandler(handler, h2s),
|
|
ReadHeaderTimeout: 5 * time.Second,
|
|
}
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
t.Cleanup(func() {
|
|
srv.Close()
|
|
ln.Close()
|
|
})
|
|
|
|
go srv.Serve(ln)
|
|
return ln.Addr().String()
|
|
}
|
|
|
|
// dialTLSToProxy dials a TCP connection to the proxy, does a TLS handshake
|
|
// with the given serverName (skipping cert verification for self-signed),
|
|
// and returns an *http.Client configured to use that connection for HTTP/2.
|
|
func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
|
|
t.Helper()
|
|
|
|
tlsConf := &tls.Config{
|
|
ServerName: serverName,
|
|
InsecureSkipVerify: true,
|
|
NextProtos: []string{"h2"},
|
|
}
|
|
|
|
conn, err := tls.DialWithDialer(
|
|
&net.Dialer{Timeout: 5 * time.Second},
|
|
"tcp", proxyAddr, tlsConf,
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("TLS dial: %v", err)
|
|
}
|
|
t.Cleanup(func() { conn.Close() })
|
|
|
|
// Create an HTTP/2 client transport over this single connection.
|
|
tr := &http2.Transport{}
|
|
h2conn, err := tr.NewClientConn(conn)
|
|
if err != nil {
|
|
t.Fatalf("creating h2 client conn: %v", err)
|
|
}
|
|
|
|
return &http.Client{
|
|
Transport: &singleConnRoundTripper{cc: h2conn},
|
|
}
|
|
}
|
|
|
|
// singleConnRoundTripper is an http.RoundTripper that uses a single HTTP/2
|
|
// client connection.
|
|
type singleConnRoundTripper struct {
|
|
cc *http2.ClientConn
|
|
}
|
|
|
|
func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
return s.cc.RoundTrip(req)
|
|
}
|
|
|
|
// serveL7Route starts l7.Serve in a goroutine for a single connection.
|
|
// Returns when the goroutine completes.
|
|
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
|
|
t.Helper()
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
|
ctx := context.Background()
|
|
|
|
go func() {
|
|
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
|
|
if l7Err != nil {
|
|
t.Logf("l7.Serve: %v", l7Err)
|
|
}
|
|
}()
|
|
}
|
|
|
|
func TestL7H2CBackend(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "l7.test")
|
|
|
|
// Start an h2c backend.
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Backend", "ok")
|
|
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
|
|
}))
|
|
|
|
// Start a TCP listener for the L7 proxy.
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
|
|
// Accept one connection and run L7 serve.
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
|
// No peeked bytes — the client is connecting directly with TLS.
|
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
|
}()
|
|
|
|
// Connect as an HTTP/2 TLS client.
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
|
|
|
|
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo"))
|
|
if err != nil {
|
|
t.Fatalf("GET: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if got := string(body); got != "hello from backend, path=/foo" {
|
|
t.Fatalf("body = %q, want %q", got, "hello from backend, path=/foo")
|
|
}
|
|
|
|
if resp.Header.Get("X-Backend") != "ok" {
|
|
t.Fatalf("X-Backend header missing or wrong")
|
|
}
|
|
}
|
|
|
|
func TestL7ForwardingHeaders(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "headers.test")
|
|
|
|
// Backend that echoes the forwarding headers.
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
|
|
r.Header.Get("X-Forwarded-For"),
|
|
r.Header.Get("X-Forwarded-Proto"),
|
|
r.Header.Get("X-Real-IP"),
|
|
)
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
|
}()
|
|
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
|
|
resp, err := client.Get("https://headers.test/")
|
|
if err != nil {
|
|
t.Fatalf("GET: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
|
|
if string(body) != want {
|
|
t.Fatalf("body = %q, want %q", body, want)
|
|
}
|
|
}
|
|
|
|
func TestL7BackendUnreachable(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "unreachable.test")
|
|
|
|
// Find a port that nothing is listening on.
|
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
deadAddr := ln.Addr().String()
|
|
ln.Close()
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: deadAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 1 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
|
}()
|
|
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
|
|
resp, err := client.Get("https://unreachable.test/")
|
|
if err != nil {
|
|
t.Fatalf("GET: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusBadGateway {
|
|
t.Fatalf("status = %d, want 502", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestIsTimeoutError(t *testing.T) {
|
|
// context.DeadlineExceeded is a timeout.
|
|
if !isTimeoutError(context.DeadlineExceeded) {
|
|
t.Fatal("expected DeadlineExceeded to be a timeout error")
|
|
}
|
|
|
|
// A net timeout error is a timeout.
|
|
netErr := &net.OpError{Op: "dial", Err: &timeoutErr{}}
|
|
if !isTimeoutError(netErr) {
|
|
t.Fatal("expected net timeout to be a timeout error")
|
|
}
|
|
|
|
// A regular error is not a timeout.
|
|
if isTimeoutError(fmt.Errorf("connection refused")) {
|
|
t.Fatal("expected non-timeout error to return false")
|
|
}
|
|
}
|
|
|
|
// timeoutErr is a minimal net.Error whose Timeout method reports true and
// whose Temporary method reports false. It simulates a network timeout
// without touching the network.
type timeoutErr struct{}

// Compile-time check that *timeoutErr satisfies net.Error.
var _ net.Error = (*timeoutErr)(nil)

func (e *timeoutErr) Error() string   { return "timeout" }
func (e *timeoutErr) Timeout() bool   { return true }
func (e *timeoutErr) Temporary() bool { return false }
|
|
|
|
func TestL7MultipleRequests(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "multi.test")
|
|
|
|
var reqCount int
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
reqCount++
|
|
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
|
}()
|
|
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
|
|
|
|
// Send multiple requests over the same HTTP/2 connection.
|
|
for i := range 3 {
|
|
path := fmt.Sprintf("/req%d", i)
|
|
resp, err := client.Get(fmt.Sprintf("https://multi.test%s", path))
|
|
if err != nil {
|
|
t.Fatalf("GET %s: %v", path, err)
|
|
}
|
|
body, _ := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
|
|
want := fmt.Sprintf("req=%d path=%s", i+1, path)
|
|
if string(body) != want {
|
|
t.Fatalf("request %d: body = %q, want %q", i, body, want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestL7LargeResponse(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "large.test")
|
|
|
|
// Backend sends a 1 MB response.
|
|
largeBody := make([]byte, 1<<20)
|
|
for i := range largeBody {
|
|
largeBody[i] = byte(i % 256)
|
|
}
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Write(largeBody)
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
|
}()
|
|
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
|
|
resp, err := client.Get("https://large.test/")
|
|
if err != nil {
|
|
t.Fatalf("GET: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
if len(body) != len(largeBody) {
|
|
t.Fatalf("got %d bytes, want %d", len(body), len(largeBody))
|
|
}
|
|
}
|
|
|
|
func TestL7GRPCTrailers(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "trailers.test")
|
|
|
|
// Backend that sets HTTP trailers (used by gRPC for status).
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Trailer", "Grpc-Status, Grpc-Message")
|
|
w.Header().Set("Content-Type", "application/grpc")
|
|
w.WriteHeader(200)
|
|
// Flush to send headers.
|
|
if f, ok := w.(http.Flusher); ok {
|
|
f.Flush()
|
|
}
|
|
// Set trailers.
|
|
w.Header().Set("Grpc-Status", "0")
|
|
w.Header().Set("Grpc-Message", "OK")
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
|
}()
|
|
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
|
|
req, _ := http.NewRequest("POST", "https://trailers.test/grpc.test.Service/Method", nil)
|
|
req.Header.Set("Content-Type", "application/grpc")
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("POST: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Read body to trigger trailer delivery.
|
|
io.ReadAll(resp.Body)
|
|
|
|
// Verify trailers were forwarded through the proxy.
|
|
grpcStatus := resp.Trailer.Get("Grpc-Status")
|
|
if grpcStatus != "0" {
|
|
t.Fatalf("Grpc-Status trailer = %q, want %q", grpcStatus, "0")
|
|
}
|
|
grpcMessage := resp.Trailer.Get("Grpc-Message")
|
|
if grpcMessage != "OK" {
|
|
t.Fatalf("Grpc-Message trailer = %q, want %q", grpcMessage, "OK")
|
|
}
|
|
}
|
|
|
|
func TestL7HTTP11Fallback(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "http11.test")
|
|
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, "proto=%s", r.Proto)
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
|
}()
|
|
|
|
// Connect with HTTP/1.1 only (no h2 ALPN).
|
|
tlsConf := &tls.Config{
|
|
ServerName: "http11.test",
|
|
InsecureSkipVerify: true,
|
|
NextProtos: []string{"http/1.1"},
|
|
}
|
|
tr := &http.Transport{TLSClientConfig: tlsConf}
|
|
client := &http.Client{Transport: tr}
|
|
|
|
resp, err := client.Get(fmt.Sprintf("https://%s/", proxyLn.Addr().String()))
|
|
if err != nil {
|
|
t.Fatalf("GET: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 200 {
|
|
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
|
}
|
|
body, _ := io.ReadAll(resp.Body)
|
|
// The backend sees HTTP/2 (proxied via h2c) regardless of client protocol.
|
|
// Just verify we got a response — the protocol the backend sees depends
|
|
// on the h2c transport.
|
|
if len(body) == 0 {
|
|
t.Fatal("empty response body")
|
|
}
|
|
}
|
|
|
|
func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "policy.test")
|
|
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, "should-not-reach")
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
Policies: []PolicyRule{
|
|
{Type: "block_user_agent", Value: "EvilBot"},
|
|
},
|
|
}
|
|
|
|
go func() {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
|
}()
|
|
|
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
|
|
req, _ := http.NewRequest("GET", "https://policy.test/", nil)
|
|
req.Header.Set("User-Agent", "EvilBot/1.0")
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("GET: %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != 403 {
|
|
t.Fatalf("status = %d, want 403", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
|
|
certPath, keyPath := testCert(t, "reqhdr.test")
|
|
|
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprint(w, "ok")
|
|
}))
|
|
|
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("proxy listen: %v", err)
|
|
}
|
|
defer proxyLn.Close()
|
|
|
|
route := RouteConfig{
|
|
Backend: backendAddr,
|
|
TLSCert: certPath,
|
|
TLSKey: keyPath,
|
|
ConnectTimeout: 5 * time.Second,
|
|
Policies: []PolicyRule{
|
|
{Type: "require_header", Value: "X-Auth-Token"},
|
|
},
|
|
}
|
|
|
|
// Accept two connections (one blocked, one allowed).
|
|
go func() {
|
|
for range 2 {
|
|
conn, err := proxyLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
go func() {
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
|
}()
|
|
}
|
|
}()
|
|
|
|
// Without the required header → 403.
|
|
client1 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
|
|
resp1, err := client1.Get("https://reqhdr.test/")
|
|
if err != nil {
|
|
t.Fatalf("GET without header: %v", err)
|
|
}
|
|
resp1.Body.Close()
|
|
if resp1.StatusCode != 403 {
|
|
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
|
|
}
|
|
|
|
// With the required header → 200.
|
|
client2 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
|
|
req, _ := http.NewRequest("GET", "https://reqhdr.test/", nil)
|
|
req.Header.Set("X-Auth-Token", "valid-token")
|
|
resp2, err := client2.Do(req)
|
|
if err != nil {
|
|
t.Fatalf("GET with header: %v", err)
|
|
}
|
|
defer resp2.Body.Close()
|
|
body, _ := io.ReadAll(resp2.Body)
|
|
if resp2.StatusCode != 200 {
|
|
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)
|
|
}
|
|
if string(body) != "ok" {
|
|
t.Fatalf("body = %q, want %q", body, "ok")
|
|
}
|
|
}
|