Initial implementation of mc-proxy
Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
175
internal/sni/sni.go
Normal file
175
internal/sni/sni.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxBufferSize = 16384 // 16 KiB, max TLS record size
|
||||
|
||||
// Extract reads the TLS ClientHello from conn and returns the SNI hostname.
|
||||
// The returned peeked bytes contain the full ClientHello and must be forwarded
|
||||
// to the backend before starting the bidirectional copy.
|
||||
//
|
||||
// A read deadline is set on the connection to prevent slowloris attacks.
|
||||
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
|
||||
conn.SetReadDeadline(deadline)
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Read TLS record header (5 bytes).
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return "", nil, fmt.Errorf("reading TLS record header: %w", err)
|
||||
}
|
||||
|
||||
// Verify this is a TLS handshake record (content type 0x16).
|
||||
if header[0] != 0x16 {
|
||||
return "", nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0])
|
||||
}
|
||||
|
||||
// Record length.
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recordLen == 0 || recordLen > maxBufferSize-5 {
|
||||
return "", nil, fmt.Errorf("TLS record length %d out of range", recordLen)
|
||||
}
|
||||
|
||||
// Read the full record body.
|
||||
buf := make([]byte, 5+recordLen)
|
||||
copy(buf, header)
|
||||
if _, err := io.ReadFull(conn, buf[5:]); err != nil {
|
||||
return "", nil, fmt.Errorf("reading TLS record body: %w", err)
|
||||
}
|
||||
|
||||
// Parse the handshake message from the record body.
|
||||
hostname, err = parseClientHello(buf[5:])
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return hostname, buf, nil
|
||||
}
|
||||
|
||||
func parseClientHello(data []byte) (string, error) {
|
||||
if len(data) < 4 {
|
||||
return "", fmt.Errorf("handshake message too short")
|
||||
}
|
||||
|
||||
// Handshake type: 0x01 = ClientHello.
|
||||
if data[0] != 0x01 {
|
||||
return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0])
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes).
|
||||
hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
if len(data) < hsLen {
|
||||
return "", fmt.Errorf("ClientHello truncated")
|
||||
}
|
||||
data = data[:hsLen]
|
||||
|
||||
// Skip client version (2 bytes) + random (32 bytes).
|
||||
if len(data) < 34 {
|
||||
return "", fmt.Errorf("ClientHello too short for version+random")
|
||||
}
|
||||
data = data[34:]
|
||||
|
||||
// Skip session ID (1-byte length prefix).
|
||||
if len(data) < 1 {
|
||||
return "", fmt.Errorf("ClientHello too short for session ID length")
|
||||
}
|
||||
sidLen := int(data[0])
|
||||
data = data[1:]
|
||||
if len(data) < sidLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at session ID")
|
||||
}
|
||||
data = data[sidLen:]
|
||||
|
||||
// Skip cipher suites (2-byte length prefix).
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("ClientHello too short for cipher suites length")
|
||||
}
|
||||
csLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) < csLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at cipher suites")
|
||||
}
|
||||
data = data[csLen:]
|
||||
|
||||
// Skip compression methods (1-byte length prefix).
|
||||
if len(data) < 1 {
|
||||
return "", fmt.Errorf("ClientHello too short for compression methods length")
|
||||
}
|
||||
cmLen := int(data[0])
|
||||
data = data[1:]
|
||||
if len(data) < cmLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at compression methods")
|
||||
}
|
||||
data = data[cmLen:]
|
||||
|
||||
// Extensions (2-byte total length).
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("no extensions in ClientHello")
|
||||
}
|
||||
extLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) < extLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at extensions")
|
||||
}
|
||||
data = data[:extLen]
|
||||
|
||||
return findSNI(data)
|
||||
}
|
||||
|
||||
func findSNI(data []byte) (string, error) {
|
||||
for len(data) >= 4 {
|
||||
extType := binary.BigEndian.Uint16(data[:2])
|
||||
extDataLen := int(binary.BigEndian.Uint16(data[2:4]))
|
||||
data = data[4:]
|
||||
if len(data) < extDataLen {
|
||||
return "", fmt.Errorf("extension truncated")
|
||||
}
|
||||
|
||||
if extType == 0x0000 { // server_name
|
||||
return parseServerNameExtension(data[:extDataLen])
|
||||
}
|
||||
|
||||
data = data[extDataLen:]
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no SNI extension found")
|
||||
}
|
||||
|
||||
func parseServerNameExtension(data []byte) (string, error) {
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("server_name extension too short")
|
||||
}
|
||||
|
||||
// Server name list length.
|
||||
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) < listLen {
|
||||
return "", fmt.Errorf("server_name list truncated")
|
||||
}
|
||||
data = data[:listLen]
|
||||
|
||||
for len(data) >= 3 {
|
||||
nameType := data[0]
|
||||
nameLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
data = data[3:]
|
||||
if len(data) < nameLen {
|
||||
return "", fmt.Errorf("server_name entry truncated")
|
||||
}
|
||||
|
||||
if nameType == 0x00 { // hostname
|
||||
return strings.ToLower(string(data[:nameLen])), nil
|
||||
}
|
||||
|
||||
data = data[nameLen:]
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no hostname in server_name extension")
|
||||
}
|
||||
220
internal/sni/sni_test.go
Normal file
220
internal/sni/sni_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExtract(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sni string
|
||||
wantSNI string
|
||||
wantErr bool
|
||||
}{
|
||||
{"basic", "example.com", "example.com", false},
|
||||
{"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false},
|
||||
{"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
hello := buildClientHello(tt.sni)
|
||||
|
||||
go func() {
|
||||
client.Write(hello)
|
||||
}()
|
||||
|
||||
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if hostname != tt.wantSNI {
|
||||
t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI)
|
||||
}
|
||||
if len(peeked) != len(hello) {
|
||||
t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNoSNI(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
hello := buildClientHelloNoSNI()
|
||||
|
||||
go func() {
|
||||
client.Write(hello)
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for ClientHello without SNI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNotTLS(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-TLS data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTruncated(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
// Write just the TLS record header, then close.
|
||||
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
|
||||
client.Close()
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for truncated record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOversizedRecord(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
// Record header claiming a length larger than 16 KiB.
|
||||
header := []byte{0x16, 0x03, 0x01}
|
||||
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
|
||||
client.Write(header)
|
||||
client.Close()
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMultipleExtensions(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
hello := buildClientHelloWithExtraExtensions("target.example.com")
|
||||
|
||||
go func() {
|
||||
client.Write(hello)
|
||||
}()
|
||||
|
||||
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if hostname != "target.example.com" {
|
||||
t.Fatalf("got hostname %q, want %q", hostname, "target.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension.
|
||||
func buildClientHello(serverName string) []byte {
|
||||
return buildClientHelloWithExtensions(sniExtension(serverName))
|
||||
}
|
||||
|
||||
// buildClientHelloNoSNI constructs a ClientHello with no extensions.
|
||||
func buildClientHelloNoSNI() []byte {
|
||||
return buildClientHelloWithExtensions(nil)
|
||||
}
|
||||
|
||||
// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI.
|
||||
func buildClientHelloWithExtraExtensions(serverName string) []byte {
|
||||
// Dummy extension (type 0xFF01, empty data).
|
||||
dummy := []byte{0xFF, 0x01, 0x00, 0x00}
|
||||
ext := append(dummy, sniExtension(serverName)...)
|
||||
return buildClientHelloWithExtensions(ext)
|
||||
}
|
||||
|
||||
func buildClientHelloWithExtensions(extensions []byte) []byte {
|
||||
var hello []byte
|
||||
|
||||
// Client version: TLS 1.2.
|
||||
hello = append(hello, 0x03, 0x03)
|
||||
|
||||
// Random: 32 bytes of zeros.
|
||||
hello = append(hello, make([]byte, 32)...)
|
||||
|
||||
// Session ID: empty.
|
||||
hello = append(hello, 0x00)
|
||||
|
||||
// Cipher suites: one suite (TLS_RSA_WITH_AES_128_GCM_SHA256).
|
||||
hello = append(hello, 0x00, 0x02, 0x00, 0x9C)
|
||||
|
||||
// Compression methods: null.
|
||||
hello = append(hello, 0x01, 0x00)
|
||||
|
||||
// Extensions.
|
||||
if len(extensions) > 0 {
|
||||
hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
|
||||
hello = append(hello, extensions...)
|
||||
}
|
||||
|
||||
// Wrap in handshake header (type 0x01 = ClientHello).
|
||||
handshake := []byte{0x01, 0x00, 0x00, 0x00}
|
||||
handshake[1] = byte(len(hello) >> 16)
|
||||
handshake[2] = byte(len(hello) >> 8)
|
||||
handshake[3] = byte(len(hello))
|
||||
handshake = append(handshake, hello...)
|
||||
|
||||
// Wrap in TLS record header (type 0x16 = handshake, version TLS 1.0).
|
||||
record := []byte{0x16, 0x03, 0x01}
|
||||
record = binary.BigEndian.AppendUint16(record, uint16(len(handshake)))
|
||||
record = append(record, handshake...)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func sniExtension(serverName string) []byte {
|
||||
name := []byte(serverName)
|
||||
|
||||
// Server name entry: type 0x00 (hostname), length, name.
|
||||
var entry []byte
|
||||
entry = append(entry, 0x00)
|
||||
entry = binary.BigEndian.AppendUint16(entry, uint16(len(name)))
|
||||
entry = append(entry, name...)
|
||||
|
||||
// Server name list: length prefix.
|
||||
var list []byte
|
||||
list = binary.BigEndian.AppendUint16(list, uint16(len(entry)))
|
||||
list = append(list, entry...)
|
||||
|
||||
// Extension: type 0x0000 (server_name), length, data.
|
||||
var ext []byte
|
||||
ext = binary.BigEndian.AppendUint16(ext, 0x0000)
|
||||
ext = binary.BigEndian.AppendUint16(ext, uint16(len(list)))
|
||||
ext = append(ext, list...)
|
||||
|
||||
return ext
|
||||
}
|
||||
Reference in New Issue
Block a user