Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
221 lines
5.4 KiB
Go
221 lines
5.4 KiB
Go
package sni
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestExtract(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
sni string
|
|
wantSNI string
|
|
wantErr bool
|
|
}{
|
|
{"basic", "example.com", "example.com", false},
|
|
{"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false},
|
|
{"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
client, server := net.Pipe()
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
hello := buildClientHello(tt.sni)
|
|
|
|
go func() {
|
|
client.Write(hello)
|
|
}()
|
|
|
|
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
|
|
if tt.wantErr {
|
|
if err == nil {
|
|
t.Fatal("expected error, got nil")
|
|
}
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if hostname != tt.wantSNI {
|
|
t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI)
|
|
}
|
|
if len(peeked) != len(hello) {
|
|
t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExtractNoSNI(t *testing.T) {
|
|
client, server := net.Pipe()
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
hello := buildClientHelloNoSNI()
|
|
|
|
go func() {
|
|
client.Write(hello)
|
|
}()
|
|
|
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
|
if err == nil {
|
|
t.Fatal("expected error for ClientHello without SNI")
|
|
}
|
|
}
|
|
|
|
func TestExtractNotTLS(t *testing.T) {
|
|
client, server := net.Pipe()
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
go func() {
|
|
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
|
}()
|
|
|
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
|
if err == nil {
|
|
t.Fatal("expected error for non-TLS data")
|
|
}
|
|
}
|
|
|
|
func TestExtractTruncated(t *testing.T) {
|
|
client, server := net.Pipe()
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
go func() {
|
|
// Write just the TLS record header, then close.
|
|
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
|
|
client.Close()
|
|
}()
|
|
|
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
|
if err == nil {
|
|
t.Fatal("expected error for truncated record")
|
|
}
|
|
}
|
|
|
|
func TestExtractOversizedRecord(t *testing.T) {
|
|
client, server := net.Pipe()
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
go func() {
|
|
// Record header claiming a length larger than 16 KiB.
|
|
header := []byte{0x16, 0x03, 0x01}
|
|
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
|
|
client.Write(header)
|
|
client.Close()
|
|
}()
|
|
|
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
|
if err == nil {
|
|
t.Fatal("expected error for oversized record")
|
|
}
|
|
}
|
|
|
|
func TestExtractMultipleExtensions(t *testing.T) {
|
|
client, server := net.Pipe()
|
|
defer client.Close()
|
|
defer server.Close()
|
|
|
|
hello := buildClientHelloWithExtraExtensions("target.example.com")
|
|
|
|
go func() {
|
|
client.Write(hello)
|
|
}()
|
|
|
|
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if hostname != "target.example.com" {
|
|
t.Fatalf("got hostname %q, want %q", hostname, "target.example.com")
|
|
}
|
|
}
|
|
|
|
// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension.
|
|
func buildClientHello(serverName string) []byte {
|
|
return buildClientHelloWithExtensions(sniExtension(serverName))
|
|
}
|
|
|
|
// buildClientHelloNoSNI constructs a ClientHello with no extensions.
|
|
func buildClientHelloNoSNI() []byte {
|
|
return buildClientHelloWithExtensions(nil)
|
|
}
|
|
|
|
// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI.
|
|
func buildClientHelloWithExtraExtensions(serverName string) []byte {
|
|
// Dummy extension (type 0xFF01, empty data).
|
|
dummy := []byte{0xFF, 0x01, 0x00, 0x00}
|
|
ext := append(dummy, sniExtension(serverName)...)
|
|
return buildClientHelloWithExtensions(ext)
|
|
}
|
|
|
|
// buildClientHelloWithExtensions assembles a full TLS record containing a
// minimal TLS 1.2 ClientHello. extensions is the already-encoded extension
// list. When it is empty, the extensions field, including its length prefix,
// is left out altogether.
func buildClientHelloWithExtensions(extensions []byte) []byte {
	body := make([]byte, 0, 43+len(extensions))
	// client_version: TLS 1.2.
	body = append(body, 0x03, 0x03)
	// random: 32 zero bytes.
	body = append(body, make([]byte, 32)...)
	// session_id empty; cipher_suites = TLS_RSA_WITH_AES_128_GCM_SHA256 only;
	// compression_methods = null only.
	body = append(body, 0x00, 0x00, 0x02, 0x00, 0x9C, 0x01, 0x00)
	if n := len(extensions); n != 0 {
		body = binary.BigEndian.AppendUint16(body, uint16(n))
		body = append(body, extensions...)
	}

	// Handshake message: msg_type client_hello (1) followed by a 24-bit length.
	bodyLen := len(body)
	msg := make([]byte, 0, 4+bodyLen)
	msg = append(msg, 0x01, byte(bodyLen>>16), byte(bodyLen>>8), byte(bodyLen))
	msg = append(msg, body...)

	// Record layer: content type handshake (0x16), legacy version TLS 1.0,
	// then a 16-bit payload length.
	rec := make([]byte, 0, 5+len(msg))
	rec = append(rec, 0x16, 0x03, 0x01)
	rec = binary.BigEndian.AppendUint16(rec, uint16(len(msg)))
	return append(rec, msg...)
}
|
|
|
|
// sniExtension encodes a server_name extension with one host_name entry for
// serverName. The layout, with field sizes in bytes, is:
//
//	extension_type=0x0000 (2) | extension length (2) |
//	server_name_list length (2) | name_type=0 (1) | name length (2) | name
func sniExtension(serverName string) []byte {
	entryLen := 1 + 2 + len(serverName)
	listLen := 2 + entryLen

	ext := make([]byte, 0, 4+listLen)
	// Extension header: type server_name, then the length of its data.
	ext = binary.BigEndian.AppendUint16(ext, 0x0000)
	ext = binary.BigEndian.AppendUint16(ext, uint16(listLen))
	// server_name_list length prefix, then the single entry.
	ext = binary.BigEndian.AppendUint16(ext, uint16(entryLen))
	// name_type host_name.
	ext = append(ext, 0x00)
	ext = binary.BigEndian.AppendUint16(ext, uint16(len(serverName)))
	return append(ext, serverName...)
}
|