Files
netbird/util/netrelay/relay.go
Viktor Liu 10da236dae Address PR review: connection-wide idle watchdog, test hardening
- netrelay: replace per-direction read-deadline idle tracking with a
  single connection-wide watchdog that observes activity on both sides,
  so a long one-way transfer no longer trips the timeout on the quiet
  direction. IdleTimeout==0 remains a no-op (SSH and uspfilter forwarder
  call sites pass zero); only the reverse-proxy router sets one.
- netrelay tests: bound blocking peer reads/writes with deadlines so a
  broken relay fails fast; add a lower-bound assertion on the idle-timeout
  test.
- conntrack cap tests: assert that the newest flow is admitted and an
  early flow was evicted, not just that the table stayed under the cap.
- ssh client RemotePortForward: bound the localAddr dial with a 10s
  timeout so a black-holed address can't pin the accepted channel open.
2026-04-21 13:01:50 +02:00

220 lines
6.2 KiB
Go

// Package netrelay provides a bidirectional byte-copy helper for TCP-like
// connections with correct half-close propagation.
//
// When one direction reads EOF, the write side of the opposite connection is
// half-closed (CloseWrite) so the peer sees FIN, then the second direction is
// allowed to drain to its own EOF before both connections are fully closed.
// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the
// naive "cancel-both-on-first-EOF" pattern breaks.
package netrelay
import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"syscall"
"time"
)
// DebugLogger is the minimal logging interface netrelay uses to surface
// teardown errors. Debugf is the only required method, so any logger that
// exposes a printf-style Debugf can be passed without an adapter. Both
// *logrus.Entry and *nblog.Logger satisfy it; callers with richer loggers
// just expose this one shape here.
type DebugLogger interface {
// Debugf logs a debug-level message built from a printf-style format
// string and its arguments.
Debugf(format string, args ...any)
}
// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers
// that want an idle timeout but have no specific preference can use this.
// Relay never applies it implicitly: a zero Options.IdleTimeout disables idle
// tracking, so callers must set this value explicitly to opt in.
const DefaultIdleTimeout = 5 * time.Minute
// halfCloser is implemented by connections that support half-close
// (e.g. *net.TCPConn, *gonet.TCPConn). Relay type-asserts both endpoints
// against it to decide whether one direction's EOF can be propagated to the
// peer while the opposite direction keeps draining.
type halfCloser interface {
// CloseWrite half-closes the connection's write side so the peer sees
// FIN; the read side stays usable.
CloseWrite() error
}
// copyBufPool recycles the 32 KiB scratch buffers used by copyTracked so a
// busy relay does not allocate a fresh buffer per direction. Buffers are
// stored as *[]byte so that Get/Put do not allocate when boxing the slice
// header into an interface value.
var copyBufPool = sync.Pool{
	New: func() any {
		scratch := make([]byte, 32<<10)
		return &scratch
	},
}
// Options configures Relay behavior. The zero value is valid: no idle timeout,
// no logging.
type Options struct {
// IdleTimeout tears down the session if no bytes flow in either
// direction within this window. It is a connection-wide watchdog, so a
// long unidirectional transfer on one side keeps the other side alive.
// Zero disables idle tracking (Relay only starts the watchdog for
// values greater than zero).
IdleTimeout time.Duration
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
// Errors that are a normal part of teardown (see isExpectedCopyError)
// are filtered out before logging. Any logger with a Debugf method is
// accepted (logrus.Entry, uspfilter's nblog.Logger, etc.).
Logger DebugLogger
}
// Relay copies bytes in both directions between a and b until both directions
// have finished (EOF or error) or ctx is canceled. When both conns support
// CloseWrite, the end of each direction half-closes the opposite conn's write
// side (best effort) so the peer sees FIN while the other direction drains;
// otherwise the first direction to finish tears down the whole session.
//
// Both conns are fully closed by the time Relay returns: Relay calls Close
// once per conn, either from the cancellation path or synchronously on
// return, whichever happens first.
//
// a and b only need to implement io.ReadWriteCloser; connections that also
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
// watchdog that tracks reads in either direction.
//
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
// a.Write). Errors are logged via Options.Logger when set; they are not
// returned because a relay always terminates on some kind of EOF/cancel.
func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// closeBoth is shared by the cancellation goroutine and the deferred call
	// below. sync.Once makes it run exactly once and makes any concurrent
	// caller wait for the in-flight close, so the conns are guaranteed to be
	// closed before Relay returns rather than at some later point when the
	// goroutine happens to be scheduled.
	var closeOnce sync.Once
	closeBoth := func() {
		closeOnce.Do(func() {
			// Close errors are deliberately ignored: teardown is best effort
			// and there is nothing useful the caller could do with them.
			_ = a.Close()
			_ = b.Close()
		})
	}
	// Deferred after cancel, so it runs first on return.
	defer closeBoth()

	// Closing the conns is what unblocks the copy goroutines when ctx is
	// canceled (by the caller, the watchdog, or the no-half-close fallback).
	go func() {
		<-ctx.Done()
		closeBoth()
	}()

	// Both sides must support CloseWrite to propagate half-close. If either
	// doesn't, a direction's EOF can't be signaled to the peer and the other
	// direction would block forever waiting for data; in that case we fall
	// back to the cancel-both-on-first-EOF behavior.
	_, aHC := a.(halfCloser)
	_, bHC := b.(halfCloser)
	halfCloseSupported := aHC && bHC

	var (
		lastActivity atomic.Int64
		idleHit      atomic.Bool
	)
	lastActivity.Store(time.Now().UnixNano())

	if opts.IdleTimeout > 0 {
		go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
	}

	var wg sync.WaitGroup
	wg.Add(2)

	var errAToB, errBToA error

	go func() {
		defer wg.Done()
		aToB, errAToB = copyTracked(b, a, &lastActivity)
		if halfCloseSupported {
			halfClose(b)
		} else {
			cancel()
		}
	}()

	go func() {
		defer wg.Done()
		bToA, errBToA = copyTracked(a, b, &lastActivity)
		if halfCloseSupported {
			halfClose(a)
		} else {
			cancel()
		}
	}()

	wg.Wait()

	if opts.Logger != nil {
		if idleHit.Load() {
			opts.Logger.Debugf("relay closed due to idle timeout")
		}
		if errAToB != nil && !isExpectedCopyError(errAToB) {
			opts.Logger.Debugf("relay copy error (a→b): %v", errAToB)
		}
		if errBToA != nil && !isExpectedCopyError(errBToA) {
			opts.Logger.Debugf("relay copy error (b→a): %v", errBToA)
		}
	}

	return aToB, bToA
}
// watchdog implements the connection-wide idle timeout. It polls lastActivity
// (a UnixNano timestamp refreshed by the copy loops) at half the idle window,
// but never more often than every 50ms. Once the most recent activity is at
// least idle old it records the hit in idleHit, cancels the relay context and
// returns. It also returns as soon as ctx is canceled for any other reason,
// so it never outlives the relay.
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
	const minInterval = 50 * time.Millisecond

	interval := idle / 2
	if interval < minInterval {
		interval = minInterval
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		quiet := time.Since(time.Unix(0, lastActivity.Load()))
		if quiet < idle {
			continue
		}

		// Set the flag before canceling so the relay sees it once the copy
		// goroutines have unwound.
		idleHit.Store(true)
		cancel()
		return
	}
}
// copyTracked streams src into dst through a buffer borrowed from
// copyBufPool. Every read that yields data refreshes lastActivity (UnixNano)
// before the data is written, which is what lets the shared watchdog enforce
// a connection-wide idle timeout. It returns the number of bytes written to
// dst together with the first write or read error encountered; a clean end of
// stream is reported as io.EOF, exactly as src.Read returned it.
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
	pooled := copyBufPool.Get().(*[]byte)
	defer copyBufPool.Put(pooled)
	scratch := *pooled

	var copied int64
	for {
		got, rerr := src.Read(scratch)

		// Nothing to forward: either stop on the read error or poll again.
		if got <= 0 {
			if rerr != nil {
				return copied, rerr
			}
			continue
		}

		lastActivity.Store(time.Now().UnixNano())

		wrote, werr := checkedWrite(dst, scratch[:got])
		copied += wrote
		if werr != nil {
			return copied, werr
		}

		// Data that arrives together with an error is flushed first; the
		// error is surfaced only afterwards.
		if rerr != nil {
			return copied, rerr
		}
	}
}
// checkedWrite performs a single dst.Write of buf and normalizes the result.
// A byte count outside [0, len(buf)] is treated as zero bytes written. The
// writer's own error takes precedence; otherwise a count short of len(buf)
// is reported as io.ErrShortWrite. The returned count is what gets added to
// the relay's byte totals.
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
	written, werr := dst.Write(buf)

	// Never trust an out-of-range count from a misbehaving writer.
	if inRange := written >= 0 && written <= len(buf); !inRange {
		written = 0
	}

	switch {
	case werr != nil:
		return int64(written), werr
	case written < len(buf):
		return int64(written), io.ErrShortWrite
	default:
		return int64(written), nil
	}
}
// halfClose shuts down the write side of conn when it implements halfCloser
// and is a no-op otherwise. The CloseWrite error is intentionally dropped:
// half-close is best effort and the conn is fully closed during teardown
// regardless.
func halfClose(conn io.ReadWriteCloser) {
	hc, ok := conn.(halfCloser)
	if !ok {
		return
	}
	_ = hc.CloseWrite()
}
// isExpectedCopyError reports whether err is (or wraps) one of the errors a
// relay routinely ends with: a closed connection, a canceled context, EOF,
// or a reset/broken-pipe/aborted connection. Relay uses it to keep these out
// of the debug log. A nil err is not considered expected.
func isExpectedCopyError(err error) bool {
	benign := [...]error{
		net.ErrClosed,
		context.Canceled,
		io.EOF,
		syscall.ECONNRESET,
		syscall.EPIPE,
		syscall.ECONNABORTED,
	}
	for _, target := range benign {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}