mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-16 05:39:56 +00:00
[client, proxy] Harden uspfilter conntrack and share TCP relay (#5936)
This commit is contained in:
238
util/netrelay/relay.go
Normal file
238
util/netrelay/relay.go
Normal file
@@ -0,0 +1,238 @@
|
||||
// Package netrelay provides a bidirectional byte-copy helper for TCP-like
|
||||
// connections with correct half-close propagation.
|
||||
//
|
||||
// When one direction reads EOF, the write side of the opposite connection is
|
||||
// half-closed (CloseWrite) so the peer sees FIN, then the second direction is
|
||||
// allowed to drain to its own EOF before both connections are fully closed.
|
||||
// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the
|
||||
// naive "cancel-both-on-first-EOF" pattern breaks.
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DebugLogger is the minimal interface netrelay uses to surface teardown
|
||||
// errors. Both *logrus.Entry and *nblog.Logger (via its Debugf method)
|
||||
// satisfy it, so callers can pass whichever they already use without an
|
||||
// adapter. Debugf is the only required method; callers with richer
|
||||
// loggers just expose this one shape here.
|
||||
type DebugLogger interface {
|
||||
Debugf(format string, args ...any)
|
||||
}
|
||||
|
||||
// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers
|
||||
// that want an idle timeout but have no specific preference can use this.
|
||||
const DefaultIdleTimeout = 5 * time.Minute
|
||||
|
||||
// halfCloser is implemented by connections that support half-close
|
||||
// (e.g. *net.TCPConn, *gonet.TCPConn).
|
||||
type halfCloser interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
var copyBufPool = sync.Pool{
|
||||
New: func() any {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
}
|
||||
|
||||
// Options configures Relay behavior. The zero value is valid: no idle timeout,
|
||||
// no logging.
|
||||
type Options struct {
|
||||
// IdleTimeout tears down the session if no bytes flow in either
|
||||
// direction within this window. It is a connection-wide watchdog, so a
|
||||
// long unidirectional transfer on one side keeps the other side alive.
|
||||
// Zero disables idle tracking.
|
||||
IdleTimeout time.Duration
|
||||
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
|
||||
// Any logger with Debug/Debugf methods is accepted (logrus.Entry,
|
||||
// uspfilter's nblog.Logger, etc.).
|
||||
Logger DebugLogger
|
||||
}
|
||||
|
||||
// Relay copies bytes in both directions between a and b until both directions
|
||||
// EOF or ctx is canceled. On each direction's EOF it half-closes the
|
||||
// opposite conn's write side (best effort) so the peer sees FIN while the
|
||||
// other direction drains. Both conns are fully closed when Relay returns.
|
||||
//
|
||||
// a and b only need to implement io.ReadWriteCloser; connections that also
|
||||
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
|
||||
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
|
||||
// watchdog that tracks reads in either direction.
|
||||
//
|
||||
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
|
||||
// a.Write). Errors are logged via Options.Logger when set; they are not
|
||||
// returned because a relay always terminates on some kind of EOF/cancel.
|
||||
func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
closeDone := make(chan struct{})
|
||||
defer func() {
|
||||
cancel()
|
||||
<-closeDone
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = a.Close()
|
||||
_ = b.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
// Both sides must support CloseWrite to propagate half-close. If either
|
||||
// doesn't, a direction's EOF can't be signaled to the peer and the other
|
||||
// direction would block forever waiting for data; in that case we fall
|
||||
// back to the cancel-both-on-first-EOF behavior.
|
||||
_, aHC := a.(halfCloser)
|
||||
_, bHC := b.(halfCloser)
|
||||
halfCloseSupported := aHC && bHC
|
||||
|
||||
var (
|
||||
lastActivity atomic.Int64
|
||||
idleHit atomic.Bool
|
||||
)
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
|
||||
if opts.IdleTimeout > 0 {
|
||||
go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var errAToB, errBToA error
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
aToB, errAToB = copyTracked(b, a, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errAToB) {
|
||||
halfClose(b)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
bToA, errBToA = copyTracked(a, b, &lastActivity)
|
||||
if halfCloseSupported && isCleanEOF(errBToA) {
|
||||
halfClose(a)
|
||||
} else {
|
||||
cancel()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if opts.Logger != nil {
|
||||
if idleHit.Load() {
|
||||
opts.Logger.Debugf("relay closed due to idle timeout")
|
||||
}
|
||||
if errAToB != nil && !isExpectedCopyError(errAToB) {
|
||||
opts.Logger.Debugf("relay copy error (a→b): %v", errAToB)
|
||||
}
|
||||
if errBToA != nil && !isExpectedCopyError(errBToA) {
|
||||
opts.Logger.Debugf("relay copy error (b→a): %v", errBToA)
|
||||
}
|
||||
}
|
||||
|
||||
return aToB, bToA
|
||||
}
|
||||
|
||||
// watchdog enforces a connection-wide idle timeout. It cancels ctx when no
|
||||
// activity has been seen on either direction for idle. It exits as soon as
|
||||
// ctx is canceled so it doesn't outlive the relay.
|
||||
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
|
||||
// Cap the tick at 50ms so detection latency stays bounded regardless of
|
||||
// how large idle is, and fall back to idle/2 when that is smaller so
|
||||
// very short timeouts (mainly in tests) are still caught promptly.
|
||||
tick := min(idle/2, 50*time.Millisecond)
|
||||
if tick <= 0 {
|
||||
tick = time.Millisecond
|
||||
}
|
||||
t := time.NewTicker(tick)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
last := time.Unix(0, lastActivity.Load())
|
||||
if time.Since(last) >= idle {
|
||||
idleHit.Store(true)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// copyTracked copies from src to dst using a pooled buffer, updating
|
||||
// lastActivity on every successful read so a shared watchdog can enforce a
|
||||
// connection-wide idle timeout.
|
||||
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
|
||||
bufp := copyBufPool.Get().(*[]byte)
|
||||
defer copyBufPool.Put(bufp)
|
||||
|
||||
buf := *bufp
|
||||
var total int64
|
||||
for {
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
lastActivity.Store(time.Now().UnixNano())
|
||||
n, werr := checkedWrite(dst, buf[:nr])
|
||||
total += n
|
||||
if werr != nil {
|
||||
return total, werr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
|
||||
nw, err := dst.Write(buf)
|
||||
if nw < 0 || nw > len(buf) {
|
||||
nw = 0
|
||||
}
|
||||
if err != nil {
|
||||
return int64(nw), err
|
||||
}
|
||||
if nw != len(buf) {
|
||||
return int64(nw), io.ErrShortWrite
|
||||
}
|
||||
return int64(nw), nil
|
||||
}
|
||||
|
||||
func halfClose(conn io.ReadWriteCloser) {
|
||||
if hc, ok := conn.(halfCloser); ok {
|
||||
_ = hc.CloseWrite()
|
||||
}
|
||||
}
|
||||
|
||||
// isCleanEOF reports whether a copy terminated on a graceful end-of-stream.
|
||||
// Only in that case is it correct to propagate the EOF via CloseWrite on the
|
||||
// peer; any other error means the flow is broken and both directions should
|
||||
// tear down.
|
||||
func isCleanEOF(err error) bool {
|
||||
return err == nil || errors.Is(err, io.EOF)
|
||||
}
|
||||
|
||||
func isExpectedCopyError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.Is(err, io.EOF) ||
|
||||
errors.Is(err, syscall.ECONNRESET) ||
|
||||
errors.Is(err, syscall.EPIPE) ||
|
||||
errors.Is(err, syscall.ECONNABORTED)
|
||||
}
|
||||
221
util/netrelay/relay_test.go
Normal file
221
util/netrelay/relay_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package netrelay
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// tcpPair returns two connected loopback TCP conns.
|
||||
func tcpPair(t *testing.T) (*net.TCPConn, *net.TCPConn) {
|
||||
t.Helper()
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer ln.Close()
|
||||
|
||||
type result struct {
|
||||
c *net.TCPConn
|
||||
err error
|
||||
}
|
||||
ch := make(chan result, 1)
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
ch <- result{nil, err}
|
||||
return
|
||||
}
|
||||
ch <- result{c.(*net.TCPConn), nil}
|
||||
}()
|
||||
|
||||
dial, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
|
||||
r := <-ch
|
||||
require.NoError(t, r.err)
|
||||
return dial.(*net.TCPConn), r.c
|
||||
}
|
||||
|
||||
// TestRelayHalfClose exercises the shutdown(SHUT_WR) scenario that the naive
|
||||
// cancel-both-on-first-EOF pattern breaks. Client A shuts down its write
|
||||
// side; B must still be able to write a full response and A must receive
|
||||
// all of it before its read returns EOF.
|
||||
func TestRelayHalfClose(t *testing.T) {
|
||||
// Real peer pairs for each side of the relay. We relay between relayA
|
||||
// and relayB. Peer A talks through relayA; peer B talks through relayB.
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
// Bound blocking reads/writes so a broken relay fails the test instead of
|
||||
// hanging the test process.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
require.NoError(t, peerA.SetDeadline(deadline))
|
||||
require.NoError(t, peerB.SetDeadline(deadline))
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Peer A sends a request, then half-closes its write side.
|
||||
req := []byte("request-payload")
|
||||
_, err := peerA.Write(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, peerA.CloseWrite())
|
||||
|
||||
// Peer B reads the request to EOF (FIN must have propagated).
|
||||
got, err := io.ReadAll(peerB)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, req, got)
|
||||
|
||||
// Peer B writes its response; peer A must receive all of it even though
|
||||
// peer A's write side is already closed.
|
||||
resp := make([]byte, 64*1024)
|
||||
for i := range resp {
|
||||
resp[i] = byte(i)
|
||||
}
|
||||
_, err = peerB.Write(resp)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, peerB.Close())
|
||||
|
||||
gotResp, err := io.ReadAll(peerA)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, resp, gotResp)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayFullDuplex verifies bidirectional copy in the simple case.
|
||||
func TestRelayFullDuplex(t *testing.T) {
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
// Bound blocking reads/writes so a broken relay fails the test instead of
|
||||
// hanging the test process.
|
||||
deadline := time.Now().Add(5 * time.Second)
|
||||
require.NoError(t, peerA.SetDeadline(deadline))
|
||||
require.NoError(t, peerB.SetDeadline(deadline))
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
type result struct {
|
||||
got []byte
|
||||
err error
|
||||
}
|
||||
resA := make(chan result, 1)
|
||||
resB := make(chan result, 1)
|
||||
|
||||
msgAB := []byte("hello-from-a")
|
||||
msgBA := []byte("hello-from-b")
|
||||
|
||||
go func() {
|
||||
if _, err := peerA.Write(msgAB); err != nil {
|
||||
resA <- result{err: err}
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len(msgBA))
|
||||
_, err := io.ReadFull(peerA, buf)
|
||||
resA <- result{got: buf, err: err}
|
||||
_ = peerA.Close()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := peerB.Write(msgBA); err != nil {
|
||||
resB <- result{err: err}
|
||||
return
|
||||
}
|
||||
buf := make([]byte, len(msgAB))
|
||||
_, err := io.ReadFull(peerB, buf)
|
||||
resB <- result{got: buf, err: err}
|
||||
_ = peerB.Close()
|
||||
}()
|
||||
|
||||
a, b := <-resA, <-resB
|
||||
require.NoError(t, a.err)
|
||||
require.Equal(t, msgBA, a.got)
|
||||
require.NoError(t, b.err)
|
||||
require.Equal(t, msgAB, b.got)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayNoHalfCloseFallback ensures Relay terminates when the underlying
|
||||
// conns don't support CloseWrite (e.g. net.Pipe). Without the fallback to
|
||||
// cancel-both-on-first-EOF, the second direction would block forever.
|
||||
func TestRelayNoHalfCloseFallback(t *testing.T) {
|
||||
a1, a2 := net.Pipe()
|
||||
b1, b2 := net.Pipe()
|
||||
defer a1.Close()
|
||||
defer b1.Close()
|
||||
|
||||
ctx := t.Context()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, a2, b2, Options{})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Close peer A's side; a2's Read will return EOF.
|
||||
require.NoError(t, a1.Close())
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not terminate when half-close is unsupported")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRelayIdleTimeout ensures the idle watchdog tears down a silent flow.
|
||||
func TestRelayIdleTimeout(t *testing.T) {
|
||||
peerA, relayA := tcpPair(t)
|
||||
relayB, peerB := tcpPair(t)
|
||||
defer peerA.Close()
|
||||
defer peerB.Close()
|
||||
|
||||
ctx := t.Context()
|
||||
|
||||
const idle = 150 * time.Millisecond
|
||||
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, relayA, relayB, Options{IdleTimeout: idle})
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not close on idle")
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
require.GreaterOrEqual(t, elapsed, idle,
|
||||
"relay must not close before the idle timeout elapses")
|
||||
require.Less(t, elapsed, idle+500*time.Millisecond,
|
||||
"relay should close shortly after the idle timeout")
|
||||
}
|
||||
Reference in New Issue
Block a user