package sni

import (
	"bytes"
	"encoding/binary"
	"net"
	"testing"
	"time"
)

// TestExtract checks that Extract returns the lowercased SNI hostname from a
// ClientHello and reports, as its peeked bytes, exactly the bytes that were
// sent (presumably so the caller can replay them upstream).
func TestExtract(t *testing.T) {
	tests := []struct {
		name    string
		sni     string
		wantSNI string
		wantErr bool
	}{
		{"basic", "example.com", "example.com", false},
		{"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false},
		{"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, server := net.Pipe()
			defer client.Close()
			defer server.Close()

			hello := buildClientHello(tt.sni)
			// net.Pipe is synchronous, so the write must run concurrently with
			// Extract. The write error is deliberately ignored: if Extract stops
			// reading early, the deferred Close unblocks the pending Write.
			go func() { client.Write(hello) }()

			hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
			if tt.wantErr {
				if err == nil {
					t.Fatal("expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if hostname != tt.wantSNI {
				t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI)
			}
			if len(peeked) != len(hello) {
				t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello))
			}
			// Length alone does not prove the bytes were preserved verbatim.
			if !bytes.Equal(peeked, hello) {
				t.Fatal("peeked bytes differ from the ClientHello that was sent")
			}
		})
	}
}

// TestExtractNoSNI checks that a ClientHello carrying no extensions (and
// therefore no server_name) is rejected.
func TestExtractNoSNI(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	hello := buildClientHelloNoSNI()
	go func() { client.Write(hello) }()

	_, _, err := Extract(server, time.Now().Add(5*time.Second))
	if err == nil {
		t.Fatal("expected error for ClientHello without SNI")
	}
}

// TestExtractNotTLS checks that plaintext HTTP is rejected rather than parsed
// as a TLS record.
func TestExtractNotTLS(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	go func() {
		client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
	}()

	_, _, err := Extract(server, time.Now().Add(5*time.Second))
	if err == nil {
		t.Fatal("expected error for non-TLS data")
	}
}

// TestExtractTruncated checks that a record header promising 0x50 payload
// bytes, followed by connection close, yields an error.
func TestExtractTruncated(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	go func() {
		// Write just the TLS record header, then close.
		client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
		client.Close()
	}()

	_, _, err := Extract(server, time.Now().Add(5*time.Second))
	if err == nil {
		t.Fatal("expected error for truncated record")
	}
}

// TestExtractOversizedRecord checks that Extract refuses a record whose
// declared length is too large.
//
// The complete record is sent, and it is an otherwise valid ClientHello with
// SNI, so the only reason Extract can fail is the size limit. A header-only
// record would be rejected by the truncation path even if the size check were
// removed, which would let that regression go unnoticed.
func TestExtractOversizedRecord(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	// 0xFFFF is the largest length a record header can express and far exceeds
	// the 2^14-byte maximum for TLS plaintext records.
	const recordLen = 0xFFFF

	ext := sniExtension("example.com")
	// Payload size with only the SNI extension, excluding the 5-byte record header.
	basePayload := len(buildClientHelloWithExtensions(ext)) - 5
	// Grow the payload to recordLen with a padding extension (type 0x0015,
	// zero-filled data); its own 4-byte header counts toward the total.
	padLen := recordLen - basePayload - 4
	ext = binary.BigEndian.AppendUint16(ext, 0x0015)
	ext = binary.BigEndian.AppendUint16(ext, uint16(padLen))
	ext = append(ext, make([]byte, padLen)...)

	hello := buildClientHelloWithExtensions(ext)
	if len(hello) != 5+recordLen {
		t.Fatalf("built record of %d bytes, want %d", len(hello), 5+recordLen)
	}

	go func() {
		// Blocks until Extract reads everything or the pipe is closed.
		client.Write(hello)
		client.Close()
	}()

	_, _, err := Extract(server, time.Now().Add(5*time.Second))
	if err == nil {
		t.Fatal("expected error for oversized record")
	}
}

// TestExtractMultipleExtensions checks that the SNI is still found when
// another extension precedes it.
func TestExtractMultipleExtensions(t *testing.T) {
	client, server := net.Pipe()
	defer client.Close()
	defer server.Close()

	hello := buildClientHelloWithExtraExtensions("target.example.com")
	go func() { client.Write(hello) }()

	hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if hostname != "target.example.com" {
		t.Fatalf("got hostname %q, want %q", hostname, "target.example.com")
	}
}

// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension.
func buildClientHello(serverName string) []byte {
	return buildClientHelloWithExtensions(sniExtension(serverName))
}

// buildClientHelloNoSNI constructs a ClientHello with no extensions.
func buildClientHelloNoSNI() []byte {
	return buildClientHelloWithExtensions(nil)
}

// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI.
func buildClientHelloWithExtraExtensions(serverName string) []byte {
	// Dummy extension (type 0xFF01, empty data).
	dummy := []byte{0xFF, 0x01, 0x00, 0x00}
	ext := append(dummy, sniExtension(serverName)...)
	return buildClientHelloWithExtensions(ext)
}

// buildClientHelloWithExtensions wraps the raw, already-encoded extensions
// bytes in a ClientHello body, then in a handshake header, then in a TLS
// record header. When extensions is empty, the extensions block (including
// its length prefix) is omitted entirely.
//
// Lengths are truncated to 16/24 bits without checks; callers must keep
// inputs within those limits.
func buildClientHelloWithExtensions(extensions []byte) []byte {
	var hello []byte
	// Client version: TLS 1.2.
	hello = append(hello, 0x03, 0x03)
	// Random: 32 bytes of zeros.
	hello = append(hello, make([]byte, 32)...)
	// Session ID: empty.
	hello = append(hello, 0x00)
	// Cipher suites: one suite (TLS_RSA_WITH_AES_128_GCM_SHA256).
	hello = append(hello, 0x00, 0x02, 0x00, 0x9C)
	// Compression methods: null.
	hello = append(hello, 0x01, 0x00)
	// Extensions.
	if len(extensions) > 0 {
		hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
		hello = append(hello, extensions...)
	}

	// Wrap in handshake header (type 0x01 = ClientHello, 24-bit length).
	handshake := []byte{0x01, 0x00, 0x00, 0x00}
	handshake[1] = byte(len(hello) >> 16)
	handshake[2] = byte(len(hello) >> 8)
	handshake[3] = byte(len(hello))
	handshake = append(handshake, hello...)

	// Wrap in TLS record header (type 0x16 = handshake, version TLS 1.0).
	record := []byte{0x16, 0x03, 0x01}
	record = binary.BigEndian.AppendUint16(record, uint16(len(handshake)))
	record = append(record, handshake...)
	return record
}

// sniExtension encodes a server_name extension (type 0x0000) containing a
// single host_name entry for serverName.
func sniExtension(serverName string) []byte {
	name := []byte(serverName)

	// Server name entry: type 0x00 (hostname), length, name.
	var entry []byte
	entry = append(entry, 0x00)
	entry = binary.BigEndian.AppendUint16(entry, uint16(len(name)))
	entry = append(entry, name...)

	// Server name list: length prefix.
	var list []byte
	list = binary.BigEndian.AppendUint16(list, uint16(len(entry)))
	list = append(list, entry...)

	// Extension: type 0x0000 (server_name), length, data.
	var ext []byte
	ext = binary.BigEndian.AppendUint16(ext, 0x0000)
	ext = binary.BigEndian.AppendUint16(ext, uint16(len(list)))
	ext = append(ext, list...)
	return ext
}