package proxy import ( "context" "errors" "io" "net" "sync" "time" ) // Result holds the outcome of a relay operation. type Result struct { ClientBytes int64 // bytes sent from client to backend BackendBytes int64 // bytes sent from backend to client } // Relay performs bidirectional byte copying between client and backend. // The peeked bytes (the TLS ClientHello) are written to the backend first. // Relay blocks until both directions are done or ctx is cancelled. func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) { // Forward the buffered ClientHello to the backend. if len(peeked) > 0 { if _, err := backend.Write(peeked); err != nil { return Result{}, err } } // Cancel context closes both connections to unblock copy goroutines. ctx, cancel := context.WithCancel(ctx) defer cancel() go func() { <-ctx.Done() _ = client.Close() _ = backend.Close() }() var ( result Result wg sync.WaitGroup errC2B error errB2C error ) wg.Add(2) // client → backend go func() { defer wg.Done() result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout) // Half-close backend's write side. if hc, ok := backend.(interface{ CloseWrite() error }); ok { _ = hc.CloseWrite() } }() // backend → client go func() { defer wg.Done() result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout) // Half-close client's write side. if hc, ok := client.(interface{ CloseWrite() error }); ok { _ = hc.CloseWrite() } }() wg.Wait() // If context was cancelled, that's the primary error. if ctx.Err() != nil { return result, ctx.Err() } // Return the first meaningful error, if any. if errC2B != nil { return result, errC2B } return result, errB2C } // copyWithIdleTimeout copies from src to dst, resetting the idle deadline // on each successful read. 
//
// A non-positive idleTimeout disables both deadlines, so the copy has no idle
// limit. Applying it anyway would yield time.Now().Add(0), a deadline that has
// already passed, and the very first Read would fail. A Write that accepts
// fewer bytes than it was given without reporting an error is turned into
// io.ErrShortWrite, matching io.Copy, so relayed data is never silently lost.
func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) {
	useDeadlines := idleTimeout > 0

	buf := make([]byte, 32*1024)
	var total int64
	for {
		if useDeadlines {
			// Deadline errors are deliberately ignored. If setting the
			// deadline fails, the Read simply proceeds without it.
			_ = src.SetReadDeadline(time.Now().Add(idleTimeout))
		}
		nr, readErr := src.Read(buf)
		if nr > 0 {
			// Forward any bytes received before looking at readErr, as the
			// io.Reader contract requires.
			if useDeadlines {
				_ = dst.SetWriteDeadline(time.Now().Add(idleTimeout))
			}
			nw, writeErr := dst.Write(buf[:nr])
			total += int64(nw)
			if writeErr != nil {
				return total, writeErr
			}
			if nw != nr {
				return total, io.ErrShortWrite
			}
		}
		if readErr != nil {
			if errors.Is(readErr, io.EOF) {
				return total, nil
			}
			return total, readErr
		}
	}
}