Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
176 lines
4.5 KiB
Go
176 lines
4.5 KiB
Go
package sni
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// maxBufferSize is the largest fragment a single TLS record may carry:
// 2^14 = 16384 bytes (16 KiB), per RFC 8446 §5.1. The 5-byte record header
// is not counted toward this limit.
const maxBufferSize = 1 << 14
|
|
|
|
// Extract reads the TLS ClientHello from conn and returns the SNI hostname.
|
|
// The returned peeked bytes contain the full ClientHello and must be forwarded
|
|
// to the backend before starting the bidirectional copy.
|
|
//
|
|
// A read deadline is set on the connection to prevent slowloris attacks.
|
|
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
|
|
conn.SetReadDeadline(deadline)
|
|
defer conn.SetReadDeadline(time.Time{})
|
|
|
|
// Read TLS record header (5 bytes).
|
|
header := make([]byte, 5)
|
|
if _, err := io.ReadFull(conn, header); err != nil {
|
|
return "", nil, fmt.Errorf("reading TLS record header: %w", err)
|
|
}
|
|
|
|
// Verify this is a TLS handshake record (content type 0x16).
|
|
if header[0] != 0x16 {
|
|
return "", nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0])
|
|
}
|
|
|
|
// Record length.
|
|
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
|
if recordLen == 0 || recordLen > maxBufferSize-5 {
|
|
return "", nil, fmt.Errorf("TLS record length %d out of range", recordLen)
|
|
}
|
|
|
|
// Read the full record body.
|
|
buf := make([]byte, 5+recordLen)
|
|
copy(buf, header)
|
|
if _, err := io.ReadFull(conn, buf[5:]); err != nil {
|
|
return "", nil, fmt.Errorf("reading TLS record body: %w", err)
|
|
}
|
|
|
|
// Parse the handshake message from the record body.
|
|
hostname, err = parseClientHello(buf[5:])
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
|
|
return hostname, buf, nil
|
|
}
|
|
|
|
func parseClientHello(data []byte) (string, error) {
|
|
if len(data) < 4 {
|
|
return "", fmt.Errorf("handshake message too short")
|
|
}
|
|
|
|
// Handshake type: 0x01 = ClientHello.
|
|
if data[0] != 0x01 {
|
|
return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0])
|
|
}
|
|
|
|
// Handshake length (3 bytes).
|
|
hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
|
data = data[4:]
|
|
if len(data) < hsLen {
|
|
return "", fmt.Errorf("ClientHello truncated")
|
|
}
|
|
data = data[:hsLen]
|
|
|
|
// Skip client version (2 bytes) + random (32 bytes).
|
|
if len(data) < 34 {
|
|
return "", fmt.Errorf("ClientHello too short for version+random")
|
|
}
|
|
data = data[34:]
|
|
|
|
// Skip session ID (1-byte length prefix).
|
|
if len(data) < 1 {
|
|
return "", fmt.Errorf("ClientHello too short for session ID length")
|
|
}
|
|
sidLen := int(data[0])
|
|
data = data[1:]
|
|
if len(data) < sidLen {
|
|
return "", fmt.Errorf("ClientHello truncated at session ID")
|
|
}
|
|
data = data[sidLen:]
|
|
|
|
// Skip cipher suites (2-byte length prefix).
|
|
if len(data) < 2 {
|
|
return "", fmt.Errorf("ClientHello too short for cipher suites length")
|
|
}
|
|
csLen := int(binary.BigEndian.Uint16(data[:2]))
|
|
data = data[2:]
|
|
if len(data) < csLen {
|
|
return "", fmt.Errorf("ClientHello truncated at cipher suites")
|
|
}
|
|
data = data[csLen:]
|
|
|
|
// Skip compression methods (1-byte length prefix).
|
|
if len(data) < 1 {
|
|
return "", fmt.Errorf("ClientHello too short for compression methods length")
|
|
}
|
|
cmLen := int(data[0])
|
|
data = data[1:]
|
|
if len(data) < cmLen {
|
|
return "", fmt.Errorf("ClientHello truncated at compression methods")
|
|
}
|
|
data = data[cmLen:]
|
|
|
|
// Extensions (2-byte total length).
|
|
if len(data) < 2 {
|
|
return "", fmt.Errorf("no extensions in ClientHello")
|
|
}
|
|
extLen := int(binary.BigEndian.Uint16(data[:2]))
|
|
data = data[2:]
|
|
if len(data) < extLen {
|
|
return "", fmt.Errorf("ClientHello truncated at extensions")
|
|
}
|
|
data = data[:extLen]
|
|
|
|
return findSNI(data)
|
|
}
|
|
|
|
func findSNI(data []byte) (string, error) {
|
|
for len(data) >= 4 {
|
|
extType := binary.BigEndian.Uint16(data[:2])
|
|
extDataLen := int(binary.BigEndian.Uint16(data[2:4]))
|
|
data = data[4:]
|
|
if len(data) < extDataLen {
|
|
return "", fmt.Errorf("extension truncated")
|
|
}
|
|
|
|
if extType == 0x0000 { // server_name
|
|
return parseServerNameExtension(data[:extDataLen])
|
|
}
|
|
|
|
data = data[extDataLen:]
|
|
}
|
|
|
|
return "", fmt.Errorf("no SNI extension found")
|
|
}
|
|
|
|
// parseServerNameExtension decodes the body of a server_name extension
// (RFC 6066 §3) and returns the first host_name entry, lowercased.
//
// Layout: a 2-byte ServerNameList length, then entries of the form 1-byte
// NameType, 2-byte length, name bytes. NameType 0 is host_name; entries of
// any other type are skipped.
//
// RFC 6066 requires HostName to be non-empty ASCII. Empty names and names
// containing control characters, spaces, or non-ASCII bytes are rejected
// rather than normalized. Applying strings.ToLower to arbitrary bytes
// performs Unicode case mapping (e.g. U+212A KELVIN SIGN lowercases to "k")
// and rewrites invalid UTF-8 as U+FFFD. Either would let distinct raw SNI
// values collapse onto the same route key.
func parseServerNameExtension(data []byte) (string, error) {
	if len(data) < 2 {
		return "", fmt.Errorf("server_name extension too short")
	}

	// Server name list length.
	listLen := int(binary.BigEndian.Uint16(data[:2]))
	data = data[2:]
	if len(data) < listLen {
		return "", fmt.Errorf("server_name list truncated")
	}
	data = data[:listLen]

	for len(data) >= 3 {
		nameType := data[0]
		nameLen := int(binary.BigEndian.Uint16(data[1:3]))
		data = data[3:]
		if len(data) < nameLen {
			return "", fmt.Errorf("server_name entry truncated")
		}

		if nameType == 0x00 { // host_name
			name := data[:nameLen]
			if len(name) == 0 {
				return "", fmt.Errorf("empty hostname in server_name extension")
			}
			for _, c := range name {
				if c <= ' ' || c >= 0x7f {
					return "", fmt.Errorf("invalid byte 0x%02x in server_name hostname", c)
				}
			}
			// name is now known to be printable ASCII, so ToLower only folds A-Z.
			return strings.ToLower(string(name)), nil
		}

		data = data[nameLen:]
	}

	return "", fmt.Errorf("no hostname in server_name extension")
}
|