package sni

import (
	"encoding/binary"
	"fmt"
	"io"
	"net"
	"strings"
	"time"
)

const (
	// maxBufferSize is the largest TLS record payload accepted: 2^14 bytes,
	// the TLSPlaintext fragment limit of RFC 8446 §5.1 (the 5-byte record
	// header is not counted). It also caps the length of the reassembled
	// ClientHello handshake message.
	maxBufferSize = 16384 // 16 KiB

	// recordHeaderLen is the size of a TLS record header:
	// content type (1) + legacy version (2) + length (2).
	recordHeaderLen = 5

	// handshakeHeaderLen is the size of a handshake message header:
	// message type (1) + length (3).
	handshakeHeaderLen = 4

	// maxPeekedSize caps the raw bytes buffered while reading a ClientHello
	// that spans several records, so a client sending many tiny records cannot
	// inflate per-connection memory without bound.
	maxPeekedSize = 2 * maxBufferSize

	// TLS content type and handshake message type this package recognizes.
	recordTypeHandshake      = 0x16
	handshakeTypeClientHello = 0x01
)

// Extract reads the TLS ClientHello from conn and returns the SNI hostname,
// lowercased.
//
// The ClientHello may be fragmented across several consecutive handshake
// records (RFC 8446 §5.1 permits this); they are read and reassembled. The
// returned peeked bytes are the raw records exactly as read from conn, record
// headers included. They contain the full ClientHello and must be forwarded
// to the backend before starting the bidirectional copy. On error, peeked is
// nil.
//
// A read deadline is set on the connection to prevent slowloris attacks and is
// cleared again before Extract returns.
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
	if err := conn.SetReadDeadline(deadline); err != nil {
		return "", nil, fmt.Errorf("setting read deadline: %w", err)
	}
	// Best effort: the result of clearing the deadline is deliberately ignored;
	// there is no useful recovery here.
	defer conn.SetReadDeadline(time.Time{})

	var (
		buf []byte // raw records read so far, headers included
		msg []byte // concatenated record payloads (the handshake byte stream)
	)
	msgLen := 0 // total ClientHello length incl. its header; 0 until known
	for msgLen == 0 || len(msg) < msgLen {
		start := len(buf)
		if buf, err = readRecord(conn, buf); err != nil {
			return "", nil, err
		}
		msg = append(msg, buf[start+recordHeaderLen:]...)
		if msgLen == 0 {
			// The handshake header itself may be split across records, in
			// which case clientHelloLen reports 0 and we keep reading.
			if msgLen, err = clientHelloLen(msg); err != nil {
				return "", nil, err
			}
		}
	}

	if hostname, err = parseClientHello(msg); err != nil {
		return "", nil, err
	}
	return hostname, buf, nil
}

// readRecord reads one TLS handshake record from r and appends it, header and
// payload, to buf, returning the extended slice. It rejects non-handshake
// records, empty or oversized records, and records that would grow buf beyond
// maxPeekedSize.
func readRecord(r io.Reader, buf []byte) ([]byte, error) {
	var header [recordHeaderLen]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, fmt.Errorf("reading TLS record header: %w", err)
	}
	if header[0] != recordTypeHandshake {
		return nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0])
	}
	recordLen := int(binary.BigEndian.Uint16(header[3:5]))
	// Zero-length handshake fragments are forbidden by RFC 8446 §5.1.
	if recordLen == 0 || recordLen > maxBufferSize {
		return nil, fmt.Errorf("TLS record length %d out of range", recordLen)
	}
	if len(buf)+recordHeaderLen+recordLen > maxPeekedSize {
		return nil, fmt.Errorf("ClientHello records exceed %d bytes", maxPeekedSize)
	}
	start := len(buf)
	buf = append(buf, make([]byte, recordHeaderLen+recordLen)...)
	copy(buf[start:], header[:])
	if _, err := io.ReadFull(r, buf[start+recordHeaderLen:]); err != nil {
		return nil, fmt.Errorf("reading TLS record body: %w", err)
	}
	return buf, nil
}

// clientHelloLen inspects the start of the reassembled handshake stream and
// returns the total ClientHello length (header included), or 0 when the
// 4-byte handshake header has not been fully received yet. It fails fast if
// the first handshake message is not a ClientHello or is longer than
// maxBufferSize.
func clientHelloLen(msg []byte) (int, error) {
	if len(msg) > 0 && msg[0] != handshakeTypeClientHello {
		return 0, fmt.Errorf("not a ClientHello (type 0x%02x)", msg[0])
	}
	if len(msg) < handshakeHeaderLen {
		return 0, nil
	}
	n := handshakeHeaderLen + (int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3]))
	if n > maxBufferSize {
		return 0, fmt.Errorf("ClientHello length %d out of range", n)
	}
	return n, nil
}

// parseClientHello parses a ClientHello handshake message, starting at its
// 4-byte handshake header, and returns the SNI hostname. Bytes beyond the
// length declared in the handshake header are ignored.
func parseClientHello(data []byte) (string, error) {
	if len(data) < handshakeHeaderLen {
		return "", fmt.Errorf("handshake message too short")
	}
	if data[0] != handshakeTypeClientHello {
		return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0])
	}
	// Handshake length (3 bytes, big-endian).
	hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
	data = data[handshakeHeaderLen:]
	if len(data) < hsLen {
		return "", fmt.Errorf("ClientHello truncated")
	}
	data = data[:hsLen]

	// Skip client version (2 bytes) + random (32 bytes).
	if len(data) < 34 {
		return "", fmt.Errorf("ClientHello too short for version+random")
	}
	data = data[34:]

	// Skip session ID (1-byte length prefix).
	if len(data) < 1 {
		return "", fmt.Errorf("ClientHello too short for session ID length")
	}
	sidLen := int(data[0])
	data = data[1:]
	if len(data) < sidLen {
		return "", fmt.Errorf("ClientHello truncated at session ID")
	}
	data = data[sidLen:]

	// Skip cipher suites (2-byte length prefix).
	if len(data) < 2 {
		return "", fmt.Errorf("ClientHello too short for cipher suites length")
	}
	csLen := int(binary.BigEndian.Uint16(data[:2]))
	data = data[2:]
	if len(data) < csLen {
		return "", fmt.Errorf("ClientHello truncated at cipher suites")
	}
	data = data[csLen:]

	// Skip compression methods (1-byte length prefix).
	if len(data) < 1 {
		return "", fmt.Errorf("ClientHello too short for compression methods length")
	}
	cmLen := int(data[0])
	data = data[1:]
	if len(data) < cmLen {
		return "", fmt.Errorf("ClientHello truncated at compression methods")
	}
	data = data[cmLen:]

	// Extensions (2-byte total length).
	if len(data) < 2 {
		return "", fmt.Errorf("no extensions in ClientHello")
	}
	extLen := int(binary.BigEndian.Uint16(data[:2]))
	data = data[2:]
	if len(data) < extLen {
		return "", fmt.Errorf("ClientHello truncated at extensions")
	}
	data = data[:extLen]

	return findSNI(data)
}

// findSNI walks the extensions block (type 2 bytes, length 2 bytes, data) and
// parses the first server_name extension (type 0x0000) it finds.
func findSNI(data []byte) (string, error) {
	for len(data) >= 4 {
		extType := binary.BigEndian.Uint16(data[:2])
		extDataLen := int(binary.BigEndian.Uint16(data[2:4]))
		data = data[4:]
		if len(data) < extDataLen {
			return "", fmt.Errorf("extension truncated")
		}
		if extType == 0x0000 { // server_name
			return parseServerNameExtension(data[:extDataLen])
		}
		data = data[extDataLen:]
	}
	return "", fmt.Errorf("no SNI extension found")
}

// parseServerNameExtension parses the server_name extension body (a 2-byte
// list length followed by entries of name type (1), length (2), name) and
// returns the first host_name entry (name type 0x00), validated and
// lowercased.
func parseServerNameExtension(data []byte) (string, error) {
	if len(data) < 2 {
		return "", fmt.Errorf("server_name extension too short")
	}
	// Server name list length.
	listLen := int(binary.BigEndian.Uint16(data[:2]))
	data = data[2:]
	if len(data) < listLen {
		return "", fmt.Errorf("server_name list truncated")
	}
	data = data[:listLen]

	for len(data) >= 3 {
		nameType := data[0]
		nameLen := int(binary.BigEndian.Uint16(data[1:3]))
		data = data[3:]
		if len(data) < nameLen {
			return "", fmt.Errorf("server_name entry truncated")
		}
		if nameType == 0x00 { // hostname
			return normalizeHostName(data[:nameLen])
		}
		data = data[nameLen:]
	}
	return "", fmt.Errorf("no hostname in server_name extension")
}

// normalizeHostName validates a host_name taken from the server_name extension
// and returns it lowercased. RFC 6066 §3 defines HostName as a non-empty ASCII
// byte string.
//
// The name comes from an untrusted client, so it is rejected when it is empty
// or contains spaces, control characters or non-ASCII bytes, rather than being
// passed on to routing. Note that strings.ToLower would otherwise rewrite
// invalid UTF-8 bytes as U+FFFD.
func normalizeHostName(name []byte) (string, error) {
	if len(name) == 0 {
		return "", fmt.Errorf("empty hostname in server_name extension")
	}
	for _, c := range name {
		if c <= ' ' || c >= 0x7f {
			return "", fmt.Errorf("invalid byte 0x%02x in SNI hostname", c)
		}
	}
	return strings.ToLower(string(name)), nil
}