mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
133
proxy/internal/tcp/bench_test.go
Normal file
133
proxy/internal/tcp/bench_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// BenchmarkPeekClientHello_TLS measures the overhead of peeking at a real
|
||||
// TLS ClientHello and extracting the SNI. This is the per-connection cost
|
||||
// added to every TLS connection on the main listener.
|
||||
func BenchmarkPeekClientHello_TLS(b *testing.B) {
|
||||
// Pre-generate a ClientHello by capturing what crypto/tls sends.
|
||||
clientConn, serverConn := net.Pipe()
|
||||
go func() {
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: "app.example.com",
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
var hello []byte
|
||||
buf := make([]byte, 16384)
|
||||
n, _ := serverConn.Read(buf)
|
||||
hello = make([]byte, n)
|
||||
copy(hello, buf[:n])
|
||||
clientConn.Close()
|
||||
serverConn.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(hello)
|
||||
conn := &readerConn{Reader: r}
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if sni != "app.example.com" {
|
||||
b.Fatalf("unexpected SNI: %q", sni)
|
||||
}
|
||||
// Simulate draining the peeked bytes (what the HTTP server would do).
|
||||
_, _ = io.Copy(io.Discard, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeekClientHello_NonTLS measures peek overhead for non-TLS
|
||||
// connections that hit the fast non-handshake exit path.
|
||||
func BenchmarkPeekClientHello_NonTLS(b *testing.B) {
|
||||
httpReq := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(httpReq)
|
||||
conn := &readerConn{Reader: r}
|
||||
_, wrapped, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_, _ = io.Copy(io.Discard, wrapped)
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkPeekedConn_Read measures the read overhead of the peekedConn
|
||||
// wrapper compared to a plain connection read. The peeked bytes use
|
||||
// io.MultiReader which adds one indirection per Read call.
|
||||
func BenchmarkPeekedConn_Read(b *testing.B) {
|
||||
data := make([]byte, 4096)
|
||||
peeked := make([]byte, 512)
|
||||
buf := make([]byte, 1024)
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
r := bytes.NewReader(data)
|
||||
conn := &readerConn{Reader: r}
|
||||
pc := newPeekedConn(conn, peeked)
|
||||
for {
|
||||
_, err := pc.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkExtractSNI measures just the in-memory SNI parsing cost,
|
||||
// excluding I/O.
|
||||
func BenchmarkExtractSNI(b *testing.B) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
go func() {
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: "app.example.com",
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 16384)
|
||||
n, _ := serverConn.Read(buf)
|
||||
payload := make([]byte, n-tlsRecordHeaderLen)
|
||||
copy(payload, buf[tlsRecordHeaderLen:n])
|
||||
clientConn.Close()
|
||||
serverConn.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
for b.Loop() {
|
||||
sni := extractSNI(payload)
|
||||
if sni != "app.example.com" {
|
||||
b.Fatalf("unexpected SNI: %q", sni)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// readerConn wraps an io.Reader as a net.Conn for benchmarking.
|
||||
// Only Read is functional; all other methods are no-ops.
|
||||
type readerConn struct {
|
||||
io.Reader
|
||||
net.Conn
|
||||
}
|
||||
|
||||
func (c *readerConn) Read(b []byte) (int, error) {
|
||||
return c.Reader.Read(b)
|
||||
}
|
||||
76
proxy/internal/tcp/chanlistener.go
Normal file
76
proxy/internal/tcp/chanlistener.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// chanListener implements net.Listener by reading connections from a channel.
|
||||
// It allows the SNI router to feed HTTP connections to http.Server.ServeTLS.
|
||||
type chanListener struct {
|
||||
ch chan net.Conn
|
||||
addr net.Addr
|
||||
once sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newChanListener(ch chan net.Conn, addr net.Addr) *chanListener {
|
||||
return &chanListener{
|
||||
ch: ch,
|
||||
addr: addr,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection from the channel.
|
||||
func (l *chanListener) Accept() (net.Conn, error) {
|
||||
for {
|
||||
select {
|
||||
case conn, ok := <-l.ch:
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return conn, nil
|
||||
case <-l.closed:
|
||||
// Drain buffered connections before returning.
|
||||
for {
|
||||
select {
|
||||
case conn, ok := <-l.ch:
|
||||
if !ok {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
_ = conn.Close()
|
||||
default:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close signals the listener to stop accepting connections and drains
|
||||
// any buffered connections that have not yet been accepted.
|
||||
func (l *chanListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
close(l.closed)
|
||||
for {
|
||||
select {
|
||||
case conn, ok := <-l.ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
_ = conn.Close()
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the listener's network address.
|
||||
func (l *chanListener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
var _ net.Listener = (*chanListener)(nil)
|
||||
39
proxy/internal/tcp/peekedconn.go
Normal file
39
proxy/internal/tcp/peekedconn.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// peekedConn wraps a net.Conn and prepends previously peeked bytes
|
||||
// so that readers see the full original stream transparently.
|
||||
type peekedConn struct {
|
||||
net.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newPeekedConn(conn net.Conn, peeked []byte) *peekedConn {
|
||||
return &peekedConn{
|
||||
Conn: conn,
|
||||
reader: io.MultiReader(bytes.NewReader(peeked), conn),
|
||||
}
|
||||
}
|
||||
|
||||
// Read replays the peeked bytes first, then reads from the underlying conn.
|
||||
func (c *peekedConn) Read(b []byte) (int, error) {
|
||||
return c.reader.Read(b)
|
||||
}
|
||||
|
||||
// CloseWrite delegates to the underlying connection if it supports
|
||||
// half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
|
||||
// as an interface hides the concrete type's CloseWrite method, making
|
||||
// half-close a silent no-op for all SNI-routed connections.
|
||||
func (c *peekedConn) CloseWrite() error {
|
||||
if hc, ok := c.Conn.(halfCloser); ok {
|
||||
return hc.CloseWrite()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ halfCloser = (*peekedConn)(nil)
|
||||
29
proxy/internal/tcp/proxyprotocol.go
Normal file
29
proxy/internal/tcp/proxyprotocol.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
)
|
||||
|
||||
// writeProxyProtoV2 sends a PROXY protocol v2 header to the backend connection,
|
||||
// conveying the real client address.
|
||||
func writeProxyProtoV2(client, backend net.Conn) error {
|
||||
tp := proxyproto.TCPv4
|
||||
if addr, ok := client.RemoteAddr().(*net.TCPAddr); ok && addr.IP.To4() == nil {
|
||||
tp = proxyproto.TCPv6
|
||||
}
|
||||
|
||||
header := &proxyproto.Header{
|
||||
Version: 2,
|
||||
Command: proxyproto.PROXY,
|
||||
TransportProtocol: tp,
|
||||
SourceAddr: client.RemoteAddr(),
|
||||
DestinationAddr: client.LocalAddr(),
|
||||
}
|
||||
if _, err := header.WriteTo(backend); err != nil {
|
||||
return fmt.Errorf("write PROXY protocol v2 header: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
128
proxy/internal/tcp/proxyprotocol_test.go
Normal file
128
proxy/internal/tcp/proxyprotocol_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/pires/go-proxyproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWriteProxyProtoV2_IPv4(t *testing.T) {
|
||||
// Set up a real TCP listener and dial to get connections with real addresses.
|
||||
ln, err := net.Listen("tcp4", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
var err error
|
||||
serverConn, err = ln.Accept()
|
||||
if err != nil {
|
||||
t.Error("accept failed:", err)
|
||||
}
|
||||
close(accepted)
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp4", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
<-accepted
|
||||
defer serverConn.Close()
|
||||
|
||||
// Use a pipe as the backend: write the header to one end, read from the other.
|
||||
backendRead, backendWrite := net.Pipe()
|
||||
defer backendRead.Close()
|
||||
defer backendWrite.Close()
|
||||
|
||||
// serverConn is the "client" arg: RemoteAddr is the source, LocalAddr is the destination.
|
||||
writeDone := make(chan error, 1)
|
||||
go func() {
|
||||
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
|
||||
}()
|
||||
|
||||
// Read the PROXY protocol header from the backend read side.
|
||||
header, err := proxyproto.Read(bufio.NewReader(backendRead))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, header, "should have received a proxy protocol header")
|
||||
|
||||
writeErr := <-writeDone
|
||||
require.NoError(t, writeErr)
|
||||
|
||||
assert.Equal(t, byte(2), header.Version, "version should be 2")
|
||||
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
|
||||
assert.Equal(t, proxyproto.TCPv4, header.TransportProtocol, "transport should be TCPv4")
|
||||
|
||||
// serverConn.RemoteAddr() is the client's address (source in the header).
|
||||
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
|
||||
actualSrc := header.SourceAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
|
||||
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
|
||||
|
||||
// serverConn.LocalAddr() is the server's address (destination in the header).
|
||||
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
|
||||
actualDst := header.DestinationAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
|
||||
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
|
||||
}
|
||||
|
||||
func TestWriteProxyProtoV2_IPv6(t *testing.T) {
|
||||
// Set up a real TCP6 listener on loopback.
|
||||
ln, err := net.Listen("tcp6", "[::1]:0")
|
||||
if err != nil {
|
||||
t.Skip("IPv6 not available:", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
accepted := make(chan struct{})
|
||||
go func() {
|
||||
var err error
|
||||
serverConn, err = ln.Accept()
|
||||
if err != nil {
|
||||
t.Error("accept failed:", err)
|
||||
}
|
||||
close(accepted)
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp6", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer clientConn.Close()
|
||||
|
||||
<-accepted
|
||||
defer serverConn.Close()
|
||||
|
||||
backendRead, backendWrite := net.Pipe()
|
||||
defer backendRead.Close()
|
||||
defer backendWrite.Close()
|
||||
|
||||
writeDone := make(chan error, 1)
|
||||
go func() {
|
||||
writeDone <- writeProxyProtoV2(serverConn, backendWrite)
|
||||
}()
|
||||
|
||||
header, err := proxyproto.Read(bufio.NewReader(backendRead))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, header, "should have received a proxy protocol header")
|
||||
|
||||
writeErr := <-writeDone
|
||||
require.NoError(t, writeErr)
|
||||
|
||||
assert.Equal(t, byte(2), header.Version, "version should be 2")
|
||||
assert.Equal(t, proxyproto.PROXY, header.Command, "command should be PROXY")
|
||||
assert.Equal(t, proxyproto.TCPv6, header.TransportProtocol, "transport should be TCPv6")
|
||||
|
||||
expectedSrc := serverConn.RemoteAddr().(*net.TCPAddr)
|
||||
actualSrc := header.SourceAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedSrc.IP.String(), actualSrc.IP.String(), "source IP should match client remote addr")
|
||||
assert.Equal(t, expectedSrc.Port, actualSrc.Port, "source port should match client remote addr")
|
||||
|
||||
expectedDst := serverConn.LocalAddr().(*net.TCPAddr)
|
||||
actualDst := header.DestinationAddr.(*net.TCPAddr)
|
||||
assert.Equal(t, expectedDst.IP.String(), actualDst.IP.String(), "destination IP should match server local addr")
|
||||
assert.Equal(t, expectedDst.Port, actualDst.Port, "destination port should match server local addr")
|
||||
}
|
||||
156
proxy/internal/tcp/relay.go
Normal file
156
proxy/internal/tcp/relay.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
|
||||
var errIdleTimeout = errors.New("idle timeout")
|
||||
|
||||
// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
|
||||
// A zero value disables idle timeout checking.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
|
||||
// EOF to the remote by closing the write side while keeping the read
|
||||
// side open so the other direction can drain.
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Relay copies data bidirectionally between src and dst until both
|
||||
// sides are done or the context is canceled. When idleTimeout is
|
||||
// non-zero, each direction's read is deadline-guarded; if no data
|
||||
// flows within the timeout the connection is torn down. When one
|
||||
// direction finishes, it half-closes the write side of the
|
||||
// destination (if supported) to signal EOF, allowing the other
|
||||
// direction to drain gracefully before the full connection teardown.
|
||||
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = src.Close()
|
||||
_ = dst.Close()
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errSrcToDst, errDstToSrc error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
|
||||
halfClose(dst)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
|
||||
halfClose(src)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
|
||||
logger.Debug("relay closed due to idle timeout")
|
||||
}
|
||||
if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
|
||||
logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
|
||||
}
|
||||
if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
|
||||
logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
|
||||
}
|
||||
|
||||
return srcToDst, dstToSrc
|
||||
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
|
||||
// When idleTimeout > 0 it sets a read deadline on src before each
|
||||
// read and treats a timeout as an idle-triggered close.
|
||||
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
if idleTimeout <= 0 {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
conn, ok := src.(net.Conn)
|
||||
if !ok {
|
||||
return io.CopyBuffer(dst, src, *bufp)
|
||||
}
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
|
||||
return total, err
|
||||
}
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
n, err := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if netutil.IsTimeout(readErr) {
|
||||
return total, errIdleTimeout
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkedWrite writes buf to dst and returns the number of bytes written.
|
||||
// It guards against short writes and negative counts per io.Copy convention.
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err)
|
||||
}
|
||||
|
||||
// halfClose attempts to half-close the write side of the connection.
|
||||
// If the connection does not support half-close, this is a no-op.
|
||||
func halfClose(conn net.Conn) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
// Best-effort; the full close will follow shortly.
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
210
proxy/internal/tcp/relay_test.go
Normal file
210
proxy/internal/tcp/relay_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/netutil"
|
||||
)
|
||||
|
||||
func TestRelay_BidirectionalCopy(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
ctx := context.Background()
|
||||
|
||||
srcData := []byte("hello from src")
|
||||
dstData := []byte("hello from dst")
|
||||
|
||||
// dst side: write response first, then read + close.
|
||||
go func() {
|
||||
_, _ = dstClient.Write(dstData)
|
||||
buf := make([]byte, 256)
|
||||
_, _ = dstClient.Read(buf)
|
||||
dstClient.Close()
|
||||
}()
|
||||
|
||||
// src side: read the response, then send data + close.
|
||||
go func() {
|
||||
buf := make([]byte, 256)
|
||||
_, _ = srcClient.Read(buf)
|
||||
_, _ = srcClient.Write(srcData)
|
||||
srcClient.Close()
|
||||
}()
|
||||
|
||||
s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
|
||||
assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst")
|
||||
assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src")
|
||||
}
|
||||
|
||||
func TestRelay_ContextCancellation(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
defer srcClient.Close()
|
||||
defer dstClient.Close()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Cancel should cause Relay to return.
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Relay did not return after context cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelay_OneSideClosed(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
defer dstClient.Close()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
ctx := context.Background()
|
||||
|
||||
// Close src immediately. Relay should complete without hanging.
|
||||
srcClient.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Relay did not return after one side closed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelay_LargeTransfer(t *testing.T) {
|
||||
srcClient, srcServer := net.Pipe()
|
||||
dstClient, dstServer := net.Pipe()
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
ctx := context.Background()
|
||||
|
||||
// 1MB of data.
|
||||
data := make([]byte, 1<<20)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_, _ = srcClient.Write(data)
|
||||
srcClient.Close()
|
||||
}()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
received, err := io.ReadAll(dstClient)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if len(received) != len(data) {
|
||||
errCh <- fmt.Errorf("expected %d bytes, got %d", len(data), len(received))
|
||||
return
|
||||
}
|
||||
errCh <- nil
|
||||
dstClient.Close()
|
||||
}()
|
||||
|
||||
s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0)
|
||||
assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes")
|
||||
require.NoError(t, <-errCh)
|
||||
}
|
||||
|
||||
func TestRelay_IdleTimeout(t *testing.T) {
|
||||
// Use real TCP connections so SetReadDeadline works (net.Pipe
|
||||
// does not support deadlines).
|
||||
srcLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srcLn.Close()
|
||||
|
||||
dstLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dstLn.Close()
|
||||
|
||||
srcClient, err := net.Dial("tcp", srcLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer srcClient.Close()
|
||||
|
||||
srcServer, err := srcLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dstClient, err := net.Dial("tcp", dstLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer dstClient.Close()
|
||||
|
||||
dstServer, err := dstLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
logger := log.NewEntry(log.StandardLogger())
|
||||
ctx := context.Background()
|
||||
|
||||
// Send initial data to prove the relay works.
|
||||
go func() {
|
||||
_, _ = srcClient.Write([]byte("ping"))
|
||||
}()
|
||||
|
||||
done := make(chan struct{})
|
||||
var s2d, d2s int64
|
||||
go func() {
|
||||
s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Read the forwarded data on the dst side.
|
||||
buf := make([]byte, 64)
|
||||
n, err := dstClient.Read(buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ping", string(buf[:n]))
|
||||
|
||||
// Now stop sending. The relay should close after the idle timeout.
|
||||
select {
|
||||
case <-done:
|
||||
assert.Greater(t, s2d, int64(0), "should have transferred initial data")
|
||||
_ = d2s
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Relay did not exit after idle timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsExpectedError(t *testing.T) {
|
||||
assert.True(t, netutil.IsExpectedError(net.ErrClosed))
|
||||
assert.True(t, netutil.IsExpectedError(context.Canceled))
|
||||
assert.True(t, netutil.IsExpectedError(io.EOF))
|
||||
assert.False(t, netutil.IsExpectedError(io.ErrUnexpectedEOF))
|
||||
}
|
||||
570
proxy/internal/tcp/router.go
Normal file
570
proxy/internal/tcp/router.go
Normal file
@@ -0,0 +1,570 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/accesslog"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
// defaultDialTimeout is the fallback dial timeout when no per-route
|
||||
// timeout is configured.
|
||||
const defaultDialTimeout = 30 * time.Second
|
||||
|
||||
// SNIHost is a typed key for SNI hostname lookups.
|
||||
type SNIHost string
|
||||
|
||||
// RouteType specifies how a connection should be handled.
|
||||
type RouteType int
|
||||
|
||||
const (
|
||||
// RouteHTTP routes the connection through the HTTP reverse proxy.
|
||||
RouteHTTP RouteType = iota
|
||||
// RouteTCP relays the connection directly to the backend (TLS passthrough).
|
||||
RouteTCP
|
||||
)
|
||||
|
||||
const (
|
||||
// sniPeekTimeout is the deadline for reading the TLS ClientHello.
|
||||
sniPeekTimeout = 5 * time.Second
|
||||
// DefaultDrainTimeout is the default grace period for in-flight relay
|
||||
// connections to finish during shutdown.
|
||||
DefaultDrainTimeout = 30 * time.Second
|
||||
// DefaultMaxRelayConns is the default cap on concurrent TCP relay connections per router.
|
||||
DefaultMaxRelayConns = 4096
|
||||
// httpChannelBuffer is the capacity of the channel feeding HTTP connections.
|
||||
httpChannelBuffer = 4096
|
||||
)
|
||||
|
||||
// DialResolver returns a DialContextFunc for the given account.
|
||||
type DialResolver func(accountID types.AccountID) (types.DialContextFunc, error)
|
||||
|
||||
// Route describes where a connection for a given SNI should be sent.
|
||||
type Route struct {
|
||||
Type RouteType
|
||||
AccountID types.AccountID
|
||||
ServiceID types.ServiceID
|
||||
// Domain is the service's configured domain, used for access log entries.
|
||||
Domain string
|
||||
// Protocol is the frontend protocol (tcp, tls), used for access log entries.
|
||||
Protocol accesslog.Protocol
|
||||
// Target is the backend address for TCP relay (e.g. "10.0.0.5:5432").
|
||||
Target string
|
||||
// ProxyProtocol enables sending a PROXY protocol v2 header to the backend.
|
||||
ProxyProtocol bool
|
||||
// DialTimeout overrides the default dial timeout for this route.
|
||||
// Zero uses defaultDialTimeout.
|
||||
DialTimeout time.Duration
|
||||
}
|
||||
|
||||
// l4Logger sends layer-4 access log entries to the management server.
|
||||
type l4Logger interface {
|
||||
LogL4(entry accesslog.L4Entry)
|
||||
}
|
||||
|
||||
// RelayObserver receives callbacks for TCP relay lifecycle events.
|
||||
// All methods must be safe for concurrent use.
|
||||
type RelayObserver interface {
|
||||
TCPRelayStarted(accountID types.AccountID)
|
||||
TCPRelayEnded(accountID types.AccountID, duration time.Duration, srcToDst, dstToSrc int64)
|
||||
TCPRelayDialError(accountID types.AccountID)
|
||||
TCPRelayRejected(accountID types.AccountID)
|
||||
}
|
||||
|
||||
// Router accepts raw TCP connections on a shared listener, peeks at
|
||||
// the TLS ClientHello to extract the SNI, and routes the connection
|
||||
// to either the HTTP reverse proxy or a direct TCP relay.
|
||||
type Router struct {
|
||||
logger *log.Logger
|
||||
// httpCh is immutable after construction: set only in NewRouter, nil in NewPortRouter.
|
||||
httpCh chan net.Conn
|
||||
httpListener *chanListener
|
||||
mu sync.RWMutex
|
||||
routes map[SNIHost][]Route
|
||||
fallback *Route
|
||||
draining bool
|
||||
dialResolve DialResolver
|
||||
activeConns sync.WaitGroup
|
||||
activeRelays sync.WaitGroup
|
||||
relaySem chan struct{}
|
||||
drainDone chan struct{}
|
||||
observer RelayObserver
|
||||
accessLog l4Logger
|
||||
// svcCtxs tracks a context per service ID. All relay goroutines for a
|
||||
// service derive from its context; canceling it kills them immediately.
|
||||
svcCtxs map[types.ServiceID]context.Context
|
||||
svcCancels map[types.ServiceID]context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRouter creates a new SNI-based connection router.
|
||||
func NewRouter(logger *log.Logger, dialResolve DialResolver, addr net.Addr) *Router {
|
||||
httpCh := make(chan net.Conn, httpChannelBuffer)
|
||||
return &Router{
|
||||
logger: logger,
|
||||
httpCh: httpCh,
|
||||
httpListener: newChanListener(httpCh, addr),
|
||||
routes: make(map[SNIHost][]Route),
|
||||
dialResolve: dialResolve,
|
||||
relaySem: make(chan struct{}, DefaultMaxRelayConns),
|
||||
svcCtxs: make(map[types.ServiceID]context.Context),
|
||||
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// NewPortRouter creates a Router for a dedicated port without an HTTP
|
||||
// channel. Connections that don't match any SNI route fall through to
|
||||
// the fallback relay (if set) or are closed.
|
||||
func NewPortRouter(logger *log.Logger, dialResolve DialResolver) *Router {
|
||||
return &Router{
|
||||
logger: logger,
|
||||
routes: make(map[SNIHost][]Route),
|
||||
dialResolve: dialResolve,
|
||||
relaySem: make(chan struct{}, DefaultMaxRelayConns),
|
||||
svcCtxs: make(map[types.ServiceID]context.Context),
|
||||
svcCancels: make(map[types.ServiceID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPListener returns a net.Listener that yields connections routed
|
||||
// to the HTTP handler. Use this with http.Server.ServeTLS.
|
||||
func (r *Router) HTTPListener() net.Listener {
|
||||
return r.httpListener
|
||||
}
|
||||
|
||||
// AddRoute registers an SNI route. Multiple routes for the same host are
|
||||
// stored and resolved by priority at lookup time (HTTP > TCP).
|
||||
// Empty host is ignored to prevent conflicts with ECH/ESNI fallback.
|
||||
func (r *Router) AddRoute(host SNIHost, route Route) {
|
||||
if host == "" {
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
routes := r.routes[host]
|
||||
for i, existing := range routes {
|
||||
if existing.ServiceID == route.ServiceID {
|
||||
r.cancelServiceLocked(route.ServiceID)
|
||||
routes[i] = route
|
||||
return
|
||||
}
|
||||
}
|
||||
r.routes[host] = append(routes, route)
|
||||
}
|
||||
|
||||
// RemoveRoute removes the route for the given host and service ID.
|
||||
// Active relay connections for the service are closed immediately.
|
||||
// If other routes remain for the host, they are preserved.
|
||||
func (r *Router) RemoveRoute(host SNIHost, svcID types.ServiceID) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.routes[host] = slices.DeleteFunc(r.routes[host], func(route Route) bool {
|
||||
return route.ServiceID == svcID
|
||||
})
|
||||
if len(r.routes[host]) == 0 {
|
||||
delete(r.routes, host)
|
||||
}
|
||||
r.cancelServiceLocked(svcID)
|
||||
}
|
||||
|
||||
// SetFallback registers a catch-all route for connections that don't
|
||||
// match any SNI route. On a port router this handles plain TCP relay;
|
||||
// on the main router it takes priority over the HTTP channel.
|
||||
func (r *Router) SetFallback(route Route) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.fallback = &route
|
||||
}
|
||||
|
||||
// RemoveFallback clears the catch-all fallback route and closes any
|
||||
// active relay connections for the given service.
|
||||
func (r *Router) RemoveFallback(svcID types.ServiceID) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.fallback = nil
|
||||
r.cancelServiceLocked(svcID)
|
||||
}
|
||||
|
||||
// SetObserver sets the relay lifecycle observer. Must be called before Serve.
|
||||
func (r *Router) SetObserver(obs RelayObserver) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.observer = obs
|
||||
}
|
||||
|
||||
// SetAccessLogger sets the L4 access logger. Must be called before Serve.
|
||||
func (r *Router) SetAccessLogger(l l4Logger) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.accessLog = l
|
||||
}
|
||||
|
||||
// getObserver returns the current relay observer under the read lock.
|
||||
func (r *Router) getObserver() RelayObserver {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.observer
|
||||
}
|
||||
|
||||
// IsEmpty returns true when the router has no SNI routes and no fallback.
|
||||
func (r *Router) IsEmpty() bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.routes) == 0 && r.fallback == nil
|
||||
}
|
||||
|
||||
// Serve accepts connections from ln and routes them based on SNI.
|
||||
// It blocks until ctx is canceled or ln is closed, then drains
|
||||
// active relay connections up to DefaultDrainTimeout.
|
||||
func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
done := make(chan struct{})
|
||||
defer close(done)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = ln.Close()
|
||||
if r.httpListener != nil {
|
||||
r.httpListener.Close()
|
||||
}
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
if ok := r.Drain(DefaultDrainTimeout); !ok {
|
||||
r.logger.Warn("timed out waiting for connections to drain")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
r.logger.Debugf("SNI router accept: %v", err)
|
||||
continue
|
||||
}
|
||||
r.activeConns.Add(1)
|
||||
go func() {
|
||||
defer r.activeConns.Done()
|
||||
r.handleConn(ctx, conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleConn peeks at the TLS ClientHello and routes the connection.
|
||||
func (r *Router) handleConn(ctx context.Context, conn net.Conn) {
|
||||
// Fast path: when no SNI routes and no HTTP channel exist (pure TCP
|
||||
// fallback port), skip the TLS peek entirely to avoid read errors on
|
||||
// non-TLS connections and reduce latency.
|
||||
if r.isFallbackOnly() {
|
||||
r.handleUnmatched(ctx, conn)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(sniPeekTimeout)); err != nil {
|
||||
r.logger.Debugf("set SNI peek deadline: %v", err)
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
sni, wrapped, err := PeekClientHello(conn)
|
||||
if err != nil {
|
||||
r.logger.Debugf("SNI peek: %v", err)
|
||||
if wrapped != nil {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
} else {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := wrapped.SetReadDeadline(time.Time{}); err != nil {
|
||||
r.logger.Debugf("clear SNI peek deadline: %v", err)
|
||||
_ = wrapped.Close()
|
||||
return
|
||||
}
|
||||
|
||||
host := SNIHost(sni)
|
||||
route, ok := r.lookupRoute(host)
|
||||
if !ok {
|
||||
r.handleUnmatched(ctx, wrapped)
|
||||
return
|
||||
}
|
||||
|
||||
if route.Type == RouteHTTP {
|
||||
r.sendToHTTP(wrapped)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.relayTCP(ctx, wrapped, host, route); err != nil {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"sni": host,
|
||||
"service_id": route.ServiceID,
|
||||
"target": route.Target,
|
||||
}).Warnf("TCP relay: %v", err)
|
||||
_ = wrapped.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// isFallbackOnly returns true when the router has no SNI routes and no HTTP
|
||||
// channel, meaning all connections should go directly to the fallback relay.
|
||||
func (r *Router) isFallbackOnly() bool {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.routes) == 0 && r.httpCh == nil
|
||||
}
|
||||
|
||||
// handleUnmatched routes a connection that didn't match any SNI route.
|
||||
// This includes ECH/ESNI connections where the cleartext SNI is empty.
|
||||
// It tries the fallback relay first, then the HTTP channel, and closes
|
||||
// the connection if neither is available.
|
||||
func (r *Router) handleUnmatched(ctx context.Context, conn net.Conn) {
|
||||
r.mu.RLock()
|
||||
fb := r.fallback
|
||||
r.mu.RUnlock()
|
||||
|
||||
if fb != nil {
|
||||
if err := r.relayTCP(ctx, conn, SNIHost("fallback"), *fb); err != nil {
|
||||
r.logger.WithFields(log.Fields{
|
||||
"service_id": fb.ServiceID,
|
||||
"target": fb.Target,
|
||||
}).Warnf("TCP relay (fallback): %v", err)
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
r.sendToHTTP(conn)
|
||||
}
|
||||
|
||||
// lookupRoute returns the highest-priority route for the given SNI host.
|
||||
// HTTP routes take precedence over TCP routes.
|
||||
func (r *Router) lookupRoute(host SNIHost) (Route, bool) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
routes, ok := r.routes[host]
|
||||
if !ok || len(routes) == 0 {
|
||||
return Route{}, false
|
||||
}
|
||||
best := routes[0]
|
||||
for _, route := range routes[1:] {
|
||||
if route.Type < best.Type {
|
||||
best = route
|
||||
}
|
||||
}
|
||||
return best, true
|
||||
}
|
||||
|
||||
// sendToHTTP feeds the connection to the HTTP handler via the channel.
|
||||
// If no HTTP channel is configured (port router), the router is
|
||||
// draining, or the channel is full, the connection is closed.
|
||||
func (r *Router) sendToHTTP(conn net.Conn) {
|
||||
if r.httpCh == nil {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
r.mu.RLock()
|
||||
draining := r.draining
|
||||
r.mu.RUnlock()
|
||||
|
||||
if draining {
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case r.httpCh <- conn:
|
||||
default:
|
||||
r.logger.Warnf("HTTP channel full, dropping connection from %s", conn.RemoteAddr())
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Drain prevents new relay connections from starting and waits for all
|
||||
// in-flight connection handlers and active relays to finish, up to the
|
||||
// given timeout. Returns true if all completed, false on timeout.
|
||||
func (r *Router) Drain(timeout time.Duration) bool {
|
||||
r.mu.Lock()
|
||||
r.draining = true
|
||||
if r.drainDone == nil {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
r.activeConns.Wait()
|
||||
r.activeRelays.Wait()
|
||||
close(done)
|
||||
}()
|
||||
r.drainDone = done
|
||||
}
|
||||
done := r.drainDone
|
||||
r.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// cancelServiceLocked cancels and removes the context for the given service,
|
||||
// closing all its active relay connections. Must be called with mu held.
|
||||
func (r *Router) cancelServiceLocked(svcID types.ServiceID) {
|
||||
if cancel, ok := r.svcCancels[svcID]; ok {
|
||||
cancel()
|
||||
delete(r.svcCtxs, svcID)
|
||||
delete(r.svcCancels, svcID)
|
||||
}
|
||||
}
|
||||
|
||||
// relayTCP sets up and runs a bidirectional TCP relay.
|
||||
// The caller owns conn and must close it if this method returns an error.
|
||||
// On success (nil error), both conn and backend are closed by the relay.
|
||||
func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route Route) error {
|
||||
svcCtx, err := r.acquireRelay(ctx, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
<-r.relaySem
|
||||
r.activeRelays.Done()
|
||||
}()
|
||||
|
||||
backend, err := r.dialBackend(svcCtx, route)
|
||||
if err != nil {
|
||||
obs := r.getObserver()
|
||||
if obs != nil {
|
||||
obs.TCPRelayDialError(route.AccountID)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if route.ProxyProtocol {
|
||||
if err := writeProxyProtoV2(conn, backend); err != nil {
|
||||
_ = backend.Close()
|
||||
return fmt.Errorf("write PROXY protocol header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
obs := r.getObserver()
|
||||
if obs != nil {
|
||||
obs.TCPRelayStarted(route.AccountID)
|
||||
}
|
||||
|
||||
entry := r.logger.WithFields(log.Fields{
|
||||
"sni": sni,
|
||||
"service_id": route.ServiceID,
|
||||
"target": route.Target,
|
||||
})
|
||||
entry.Debug("TCP relay started")
|
||||
|
||||
start := time.Now()
|
||||
s2d, d2s := Relay(svcCtx, entry, conn, backend, DefaultIdleTimeout)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
if obs != nil {
|
||||
obs.TCPRelayEnded(route.AccountID, elapsed, s2d, d2s)
|
||||
}
|
||||
entry.Debugf("TCP relay ended (client→backend: %d bytes, backend→client: %d bytes)", s2d, d2s)
|
||||
|
||||
r.logL4Entry(route, conn, elapsed, s2d, d2s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// acquireRelay checks draining state, increments activeRelays, and acquires
|
||||
// a semaphore slot. Returns the per-service context on success.
|
||||
// The caller must release the semaphore and call activeRelays.Done() when done.
|
||||
func (r *Router) acquireRelay(ctx context.Context, route Route) (context.Context, error) {
|
||||
r.mu.Lock()
|
||||
if r.draining {
|
||||
r.mu.Unlock()
|
||||
return nil, errors.New("router is draining")
|
||||
}
|
||||
r.activeRelays.Add(1)
|
||||
svcCtx := r.getOrCreateServiceCtxLocked(ctx, route.ServiceID)
|
||||
r.mu.Unlock()
|
||||
|
||||
select {
|
||||
case r.relaySem <- struct{}{}:
|
||||
return svcCtx, nil
|
||||
default:
|
||||
r.activeRelays.Done()
|
||||
obs := r.getObserver()
|
||||
if obs != nil {
|
||||
obs.TCPRelayRejected(route.AccountID)
|
||||
}
|
||||
return nil, errors.New("TCP relay connection limit reached")
|
||||
}
|
||||
}
|
||||
|
||||
// dialBackend resolves the dialer for the route's account and dials the backend.
|
||||
func (r *Router) dialBackend(svcCtx context.Context, route Route) (net.Conn, error) {
|
||||
dialFn, err := r.dialResolve(route.AccountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve dialer: %w", err)
|
||||
}
|
||||
|
||||
dialTimeout := route.DialTimeout
|
||||
if dialTimeout <= 0 {
|
||||
dialTimeout = defaultDialTimeout
|
||||
}
|
||||
dialCtx, dialCancel := context.WithTimeout(svcCtx, dialTimeout)
|
||||
backend, err := dialFn(dialCtx, "tcp", route.Target)
|
||||
dialCancel()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial backend %s: %w", route.Target, err)
|
||||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// logL4Entry sends a TCP relay access log entry if an access logger is configured.
|
||||
func (r *Router) logL4Entry(route Route, conn net.Conn, duration time.Duration, bytesUp, bytesDown int64) {
|
||||
r.mu.RLock()
|
||||
al := r.accessLog
|
||||
r.mu.RUnlock()
|
||||
|
||||
if al == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var sourceIP netip.Addr
|
||||
if remote := conn.RemoteAddr(); remote != nil {
|
||||
if ap, err := netip.ParseAddrPort(remote.String()); err == nil {
|
||||
sourceIP = ap.Addr().Unmap()
|
||||
}
|
||||
}
|
||||
|
||||
al.LogL4(accesslog.L4Entry{
|
||||
AccountID: route.AccountID,
|
||||
ServiceID: route.ServiceID,
|
||||
Protocol: route.Protocol,
|
||||
Host: route.Domain,
|
||||
SourceIP: sourceIP,
|
||||
DurationMs: duration.Milliseconds(),
|
||||
BytesUpload: bytesUp,
|
||||
BytesDownload: bytesDown,
|
||||
})
|
||||
}
|
||||
|
||||
// getOrCreateServiceCtxLocked returns the context for a service, creating one
|
||||
// if it doesn't exist yet. The context is a child of the server context.
|
||||
// Must be called with mu held.
|
||||
func (r *Router) getOrCreateServiceCtxLocked(parent context.Context, svcID types.ServiceID) context.Context {
|
||||
if ctx, ok := r.svcCtxs[svcID]; ok {
|
||||
return ctx
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
r.svcCtxs[svcID] = ctx
|
||||
r.svcCancels[svcID] = cancel
|
||||
return ctx
|
||||
}
|
||||
1670
proxy/internal/tcp/router_test.go
Normal file
1670
proxy/internal/tcp/router_test.go
Normal file
File diff suppressed because it is too large
Load Diff
191
proxy/internal/tcp/snipeek.go
Normal file
191
proxy/internal/tcp/snipeek.go
Normal file
@@ -0,0 +1,191 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
const (
|
||||
// TLS record header is 5 bytes: ContentType(1) + Version(2) + Length(2).
|
||||
tlsRecordHeaderLen = 5
|
||||
// TLS handshake type for ClientHello.
|
||||
handshakeTypeClientHello = 1
|
||||
// TLS ContentType for handshake messages.
|
||||
contentTypeHandshake = 22
|
||||
// SNI extension type (RFC 6066).
|
||||
extensionServerName = 0
|
||||
// SNI host name type.
|
||||
sniHostNameType = 0
|
||||
// maxClientHelloLen caps the ClientHello size we're willing to buffer.
|
||||
maxClientHelloLen = 16384
|
||||
// maxSNILen is the maximum valid DNS hostname length per RFC 1035.
|
||||
maxSNILen = 253
|
||||
)
|
||||
|
||||
// PeekClientHello reads the TLS ClientHello from conn, extracts the SNI
|
||||
// server name, and returns a wrapped connection that replays the peeked
|
||||
// bytes transparently. If the data is not a valid TLS ClientHello or
|
||||
// contains no SNI extension, sni is empty and err is nil.
|
||||
//
|
||||
// ECH/ESNI: When the client uses Encrypted Client Hello (TLS 1.3), the
|
||||
// real server name is encrypted inside the encrypted_client_hello
|
||||
// extension. This parser only reads the cleartext server_name extension
|
||||
// (type 0x0000), so ECH connections return sni="" and are routed through
|
||||
// the fallback path (or HTTP channel), which is the correct behavior
|
||||
// for a transparent proxy that does not terminate TLS.
|
||||
func PeekClientHello(conn net.Conn) (sni string, wrapped net.Conn, err error) {
|
||||
// Read the 5-byte TLS record header into a small stack-friendly buffer.
|
||||
var header [tlsRecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||
return "", nil, fmt.Errorf("read TLS record header: %w", err)
|
||||
}
|
||||
|
||||
if header[0] != contentTypeHandshake {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recordLen == 0 || recordLen > maxClientHelloLen {
|
||||
return "", newPeekedConn(conn, header[:]), nil
|
||||
}
|
||||
|
||||
// Single allocation for header + payload. The peekedConn takes
|
||||
// ownership of this buffer, so no further copies are needed.
|
||||
buf := make([]byte, tlsRecordHeaderLen+recordLen)
|
||||
copy(buf, header[:])
|
||||
|
||||
n, err := io.ReadFull(conn, buf[tlsRecordHeaderLen:])
|
||||
if err != nil {
|
||||
return "", newPeekedConn(conn, buf[:tlsRecordHeaderLen+n]), fmt.Errorf("read TLS handshake payload: %w", err)
|
||||
}
|
||||
|
||||
sni = extractSNI(buf[tlsRecordHeaderLen:])
|
||||
return sni, newPeekedConn(conn, buf), nil
|
||||
}
|
||||
|
||||
// extractSNI parses a TLS handshake payload to find the SNI extension.
|
||||
// Returns empty string if the payload is not a ClientHello or has no SNI.
|
||||
func extractSNI(payload []byte) string {
|
||||
if len(payload) < 4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if payload[0] != handshakeTypeClientHello {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes, big-endian).
|
||||
handshakeLen := int(payload[1])<<16 | int(payload[2])<<8 | int(payload[3])
|
||||
if handshakeLen > len(payload)-4 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return parseSNIFromClientHello(payload[4 : 4+handshakeLen])
|
||||
}
|
||||
|
||||
// parseSNIFromClientHello walks the ClientHello message fields to reach
|
||||
// the extensions block and extract the server_name extension value.
|
||||
func parseSNIFromClientHello(msg []byte) string {
|
||||
// ClientHello layout:
|
||||
// ProtocolVersion(2) + Random(32) = 34 bytes minimum before session_id
|
||||
if len(msg) < 34 {
|
||||
return ""
|
||||
}
|
||||
|
||||
pos := 34
|
||||
|
||||
// Session ID (variable, 1 byte length prefix).
|
||||
if pos >= len(msg) {
|
||||
return ""
|
||||
}
|
||||
sessionIDLen := int(msg[pos])
|
||||
pos++
|
||||
pos += sessionIDLen
|
||||
|
||||
// Cipher suites (variable, 2 byte length prefix).
|
||||
if pos+2 > len(msg) {
|
||||
return ""
|
||||
}
|
||||
cipherSuitesLen := int(binary.BigEndian.Uint16(msg[pos : pos+2]))
|
||||
pos += 2 + cipherSuitesLen
|
||||
|
||||
// Compression methods (variable, 1 byte length prefix).
|
||||
if pos >= len(msg) {
|
||||
return ""
|
||||
}
|
||||
compMethodsLen := int(msg[pos])
|
||||
pos++
|
||||
pos += compMethodsLen
|
||||
|
||||
// Extensions (variable, 2 byte length prefix).
|
||||
if pos+2 > len(msg) {
|
||||
return ""
|
||||
}
|
||||
extensionsLen := int(binary.BigEndian.Uint16(msg[pos : pos+2]))
|
||||
pos += 2
|
||||
|
||||
extensionsEnd := pos + extensionsLen
|
||||
if extensionsEnd > len(msg) {
|
||||
return ""
|
||||
}
|
||||
|
||||
return findSNIExtension(msg[pos:extensionsEnd])
|
||||
}
|
||||
|
||||
// findSNIExtension iterates over TLS extensions and returns the host
|
||||
// name from the server_name extension, if present.
|
||||
func findSNIExtension(extensions []byte) string {
|
||||
pos := 0
|
||||
for pos+4 <= len(extensions) {
|
||||
extType := binary.BigEndian.Uint16(extensions[pos : pos+2])
|
||||
extLen := int(binary.BigEndian.Uint16(extensions[pos+2 : pos+4]))
|
||||
pos += 4
|
||||
|
||||
if pos+extLen > len(extensions) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if extType == extensionServerName {
|
||||
return parseSNIExtensionData(extensions[pos : pos+extLen])
|
||||
}
|
||||
pos += extLen
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseSNIExtensionData parses the ServerNameList structure inside an
|
||||
// SNI extension to extract the host name.
|
||||
func parseSNIExtensionData(data []byte) string {
|
||||
if len(data) < 2 {
|
||||
return ""
|
||||
}
|
||||
listLen := int(binary.BigEndian.Uint16(data[0:2]))
|
||||
if listLen > len(data)-2 {
|
||||
return ""
|
||||
}
|
||||
|
||||
list := data[2 : 2+listLen]
|
||||
pos := 0
|
||||
for pos+3 <= len(list) {
|
||||
nameType := list[pos]
|
||||
nameLen := int(binary.BigEndian.Uint16(list[pos+1 : pos+3]))
|
||||
pos += 3
|
||||
|
||||
if pos+nameLen > len(list) {
|
||||
return ""
|
||||
}
|
||||
|
||||
if nameType == sniHostNameType {
|
||||
name := list[pos : pos+nameLen]
|
||||
if nameLen > maxSNILen || bytes.ContainsRune(name, 0) {
|
||||
return ""
|
||||
}
|
||||
return string(name)
|
||||
}
|
||||
pos += nameLen
|
||||
}
|
||||
return ""
|
||||
}
|
||||
251
proxy/internal/tcp/snipeek_test.go
Normal file
251
proxy/internal/tcp/snipeek_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPeekClientHello_ValidSNI(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
const expectedSNI = "example.com"
|
||||
trailingData := []byte("trailing data after handshake")
|
||||
|
||||
go func() {
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: expectedSNI,
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
// The Handshake will send the ClientHello. It will fail because
|
||||
// our server side isn't doing a real TLS handshake, but that's
|
||||
// fine: we only need the ClientHello to be sent.
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedSNI, sni, "should extract SNI from ClientHello")
|
||||
assert.NotNil(t, wrapped, "wrapped connection should not be nil")
|
||||
|
||||
// Verify the wrapped connection replays the peeked bytes.
|
||||
// Read the first 5 bytes (TLS record header) to confirm replay.
|
||||
buf := make([]byte, 5)
|
||||
n, err := wrapped.Read(buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 5, n)
|
||||
assert.Equal(t, byte(contentTypeHandshake), buf[0], "first byte should be TLS handshake content type")
|
||||
|
||||
// Write trailing data from the client side and verify it arrives
|
||||
// through the wrapped connection after the peeked bytes.
|
||||
go func() {
|
||||
_, _ = clientConn.Write(trailingData)
|
||||
}()
|
||||
|
||||
// Drain the rest of the peeked ClientHello first.
|
||||
peekedRest := make([]byte, 16384)
|
||||
_, _ = wrapped.Read(peekedRest)
|
||||
|
||||
got := make([]byte, len(trailingData))
|
||||
n, err = io.ReadFull(wrapped, got)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, trailingData, got[:n])
|
||||
}
|
||||
|
||||
func TestPeekClientHello_MultipleSNIs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
serverName string
|
||||
expectedSNI string
|
||||
}{
|
||||
{"simple domain", "example.com", "example.com"},
|
||||
{"subdomain", "sub.example.com", "sub.example.com"},
|
||||
{"deep subdomain", "a.b.c.example.com", "a.b.c.example.com"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
go func() {
|
||||
tlsConn := tls.Client(clientConn, &tls.Config{
|
||||
ServerName: tt.serverName,
|
||||
InsecureSkipVerify: true, //nolint:gosec
|
||||
})
|
||||
_ = tlsConn.Handshake()
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, tt.expectedSNI, sni)
|
||||
assert.NotNil(t, wrapped)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekClientHello_NonTLSData(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
// Send plain HTTP data (not TLS).
|
||||
httpData := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
go func() {
|
||||
_, _ = clientConn.Write(httpData)
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sni, "should return empty SNI for non-TLS data")
|
||||
assert.NotNil(t, wrapped)
|
||||
|
||||
// Verify the wrapped connection still provides the original data.
|
||||
buf := make([]byte, len(httpData))
|
||||
n, err := io.ReadFull(wrapped, buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, httpData, buf[:n], "wrapped connection should replay original data")
|
||||
}
|
||||
|
||||
func TestPeekClientHello_TruncatedHeader(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
|
||||
// Write only 3 bytes then close, fewer than the 5-byte TLS header.
|
||||
go func() {
|
||||
_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01})
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(serverConn)
|
||||
assert.Error(t, err, "should error on truncated header")
|
||||
}
|
||||
|
||||
func TestPeekClientHello_TruncatedPayload(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer serverConn.Close()
|
||||
|
||||
// Write a valid TLS header claiming 100 bytes, but only send 10.
|
||||
go func() {
|
||||
header := []byte{0x16, 0x03, 0x01, 0x00, 0x64} // 100 bytes claimed
|
||||
_, _ = clientConn.Write(header)
|
||||
_, _ = clientConn.Write(make([]byte, 10))
|
||||
clientConn.Close()
|
||||
}()
|
||||
|
||||
_, _, err := PeekClientHello(serverConn)
|
||||
assert.Error(t, err, "should error on truncated payload")
|
||||
}
|
||||
|
||||
func TestPeekClientHello_ZeroLengthRecord(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
// TLS handshake header with zero-length payload.
|
||||
go func() {
|
||||
_, _ = clientConn.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x00})
|
||||
}()
|
||||
|
||||
sni, wrapped, err := PeekClientHello(serverConn)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, sni)
|
||||
assert.NotNil(t, wrapped)
|
||||
}
|
||||
|
||||
func TestExtractSNI_InvalidPayload(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
}{
|
||||
{"nil", nil},
|
||||
{"empty", []byte{}},
|
||||
{"too short", []byte{0x01, 0x00}},
|
||||
{"wrong handshake type", []byte{0x02, 0x00, 0x00, 0x05, 0x03, 0x03, 0x00, 0x00, 0x00}},
|
||||
{"truncated client hello", []byte{0x01, 0x00, 0x00, 0x20}}, // claims 32 bytes but has none
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Empty(t, extractSNI(tt.payload))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPeekedConn_CloseWrite(t *testing.T) {
|
||||
t.Run("delegates to underlying TCPConn", func(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
accepted := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err == nil {
|
||||
accepted <- c
|
||||
}
|
||||
}()
|
||||
|
||||
client, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
|
||||
server := <-accepted
|
||||
defer server.Close()
|
||||
|
||||
wrapped := newPeekedConn(server, []byte("peeked"))
|
||||
|
||||
// CloseWrite should succeed on a real TCP connection.
|
||||
err = wrapped.CloseWrite()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// The client should see EOF on reads after CloseWrite.
|
||||
buf := make([]byte, 1)
|
||||
_, err = client.Read(buf)
|
||||
assert.Equal(t, io.EOF, err, "client should see EOF after half-close")
|
||||
})
|
||||
|
||||
t.Run("no-op on non-halfcloser", func(t *testing.T) {
|
||||
// net.Pipe does not implement CloseWrite.
|
||||
_, server := net.Pipe()
|
||||
defer server.Close()
|
||||
|
||||
wrapped := newPeekedConn(server, []byte("peeked"))
|
||||
err := wrapped.CloseWrite()
|
||||
assert.NoError(t, err, "should be no-op on non-halfcloser")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPeekedConn_ReplayAndPassthrough(t *testing.T) {
|
||||
clientConn, serverConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
peeked := []byte("peeked-data")
|
||||
subsequent := []byte("subsequent-data")
|
||||
|
||||
wrapped := newPeekedConn(serverConn, peeked)
|
||||
|
||||
go func() {
|
||||
_, _ = clientConn.Write(subsequent)
|
||||
}()
|
||||
|
||||
// Read should return peeked data first.
|
||||
buf := make([]byte, len(peeked))
|
||||
n, err := io.ReadFull(wrapped, buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, peeked, buf[:n])
|
||||
|
||||
// Then subsequent data from the real connection.
|
||||
buf = make([]byte, len(subsequent))
|
||||
n, err = io.ReadFull(wrapped, buf)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, subsequent, buf[:n])
|
||||
}
|
||||
Reference in New Issue
Block a user