Address PR review: connection-wide idle watchdog, test hardening

- netrelay: replace per-direction read-deadline idle tracking with a
  single connection-wide watchdog that observes activity on both sides,
  so a long one-way transfer no longer trips the timeout on the quiet
  direction. IdleTimeout==0 remains a no-op (SSH and uspfilter forwarder
  call sites pass zero); only the reverse-proxy router sets one.
- netrelay tests: bound blocking peer reads/writes with deadlines so a
  broken relay fails fast; add a lower-bound assertion on the idle-timeout
  test.
- conntrack cap tests: assert that the newest flow is admitted and an
  early flow was evicted, not just that the table stayed under the cap.
- ssh client RemotePortForward: bound the localAddr dial with a 10s
  timeout so a black-holed address can't pin the accepted channel open.
This commit is contained in:
Viktor Liu
2026-04-21 13:01:50 +02:00
parent ffac18409e
commit 10da236dae
4 changed files with 97 additions and 44 deletions

View File

@@ -14,6 +14,7 @@ import (
"io"
"net"
"sync"
"sync/atomic"
"syscall"
"time"
)
@@ -31,10 +32,6 @@ type DebugLogger interface {
// that want an idle timeout but have no specific preference can use this.
const DefaultIdleTimeout = 5 * time.Minute
// ErrIdleTimeout is returned when a relay direction is torn down after no
// data flowed for longer than Options.IdleTimeout.
var ErrIdleTimeout = errors.New("idle timeout")
// halfCloser is implemented by connections that support half-close
// (e.g. *net.TCPConn, *gonet.TCPConn).
type halfCloser interface {
@@ -51,8 +48,10 @@ var copyBufPool = sync.Pool{
// Options configures Relay behavior. The zero value is valid: no idle timeout,
// no logging.
type Options struct {
// IdleTimeout tears down a direction if no bytes flow within this
// window. Zero disables idle tracking.
// IdleTimeout tears down the session if no bytes flow in either
// direction within this window. It is a connection-wide watchdog, so a
// long unidirectional transfer on one side keeps the other side alive.
// Zero disables idle tracking.
IdleTimeout time.Duration
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
// Any logger with Debug/Debugf methods is accepted (logrus.Entry,
@@ -67,8 +66,8 @@ type Options struct {
//
// a and b only need to implement io.ReadWriteCloser; connections that also
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
// propagation. Idle-timeout enforcement requires a net.Conn; for other types
// Options.IdleTimeout is ignored.
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
// watchdog that tracks reads in either direction.
//
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
// a.Write). Errors are logged via Options.Logger when set; they are not
@@ -91,6 +90,16 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT
_, bHC := b.(halfCloser)
halfCloseSupported := aHC && bHC
var (
lastActivity atomic.Int64
idleHit atomic.Bool
)
lastActivity.Store(time.Now().UnixNano())
if opts.IdleTimeout > 0 {
go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
}
var wg sync.WaitGroup
wg.Add(2)
@@ -98,7 +107,7 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT
go func() {
defer wg.Done()
aToB, errAToB = copyWithIdleTimeout(b, a, opts.IdleTimeout)
aToB, errAToB = copyTracked(b, a, &lastActivity)
if halfCloseSupported {
halfClose(b)
} else {
@@ -108,7 +117,7 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT
go func() {
defer wg.Done()
bToA, errBToA = copyWithIdleTimeout(a, b, opts.IdleTimeout)
bToA, errBToA = copyTracked(a, b, &lastActivity)
if halfCloseSupported {
halfClose(a)
} else {
@@ -119,7 +128,7 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT
wg.Wait()
if opts.Logger != nil {
if errors.Is(errAToB, ErrIdleTimeout) || errors.Is(errBToA, ErrIdleTimeout) {
if idleHit.Load() {
opts.Logger.Debugf("relay closed due to idle timeout")
}
if errAToB != nil && !isExpectedCopyError(errAToB) {
@@ -133,30 +142,41 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT
return aToB, bToA
}
// copyWithIdleTimeout copies from src to dst using a pooled buffer. When
// idleTimeout > 0 and src is a net.Conn it sets a read deadline before each
// read; a timeout is reported as ErrIdleTimeout.
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
// watchdog enforces a connection-wide idle timeout. It cancels ctx when no
// activity has been seen on either direction for idle. It exits as soon as
// ctx is canceled so it doesn't outlive the relay.
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
tick := max(idle/2, 50*time.Millisecond)
t := time.NewTicker(tick)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
last := time.Unix(0, lastActivity.Load())
if time.Since(last) >= idle {
idleHit.Store(true)
cancel()
return
}
}
}
}
// copyTracked copies from src to dst using a pooled buffer, updating
// lastActivity on every successful read so a shared watchdog can enforce a
// connection-wide idle timeout.
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
bufp := copyBufPool.Get().(*[]byte)
defer copyBufPool.Put(bufp)
if idleTimeout <= 0 {
return io.CopyBuffer(dst, src, *bufp)
}
conn, ok := src.(net.Conn)
if !ok {
return io.CopyBuffer(dst, src, *bufp)
}
buf := *bufp
var total int64
for {
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
return total, err
}
nr, readErr := src.Read(buf)
if nr > 0 {
lastActivity.Store(time.Now().UnixNano())
n, werr := checkedWrite(dst, buf[:nr])
total += n
if werr != nil {
@@ -164,9 +184,6 @@ func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration
}
}
if readErr != nil {
if isNetTimeout(readErr) {
return total, ErrIdleTimeout
}
return total, readErr
}
}
@@ -192,18 +209,7 @@ func halfClose(conn io.ReadWriteCloser) {
}
}
func isNetTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
return netErr.Timeout()
}
return false
}
func isExpectedCopyError(err error) bool {
if errors.Is(err, ErrIdleTimeout) {
return true
}
return errors.Is(err, net.ErrClosed) ||
errors.Is(err, context.Canceled) ||
errors.Is(err, io.EOF) ||

View File

@@ -51,6 +51,12 @@ func TestRelayHalfClose(t *testing.T) {
defer peerA.Close()
defer peerB.Close()
// Bound blocking reads/writes so a broken relay fails the test instead of
// hanging the test process.
deadline := time.Now().Add(5 * time.Second)
require.NoError(t, peerA.SetDeadline(deadline))
require.NoError(t, peerB.SetDeadline(deadline))
ctx := t.Context()
done := make(chan struct{})
@@ -98,6 +104,12 @@ func TestRelayFullDuplex(t *testing.T) {
defer peerA.Close()
defer peerB.Close()
// Bound blocking reads/writes so a broken relay fails the test instead of
// hanging the test process.
deadline := time.Now().Add(5 * time.Second)
require.NoError(t, peerA.SetDeadline(deadline))
require.NoError(t, peerB.SetDeadline(deadline))
ctx := t.Context()
done := make(chan struct{})
@@ -186,10 +198,12 @@ func TestRelayIdleTimeout(t *testing.T) {
ctx := t.Context()
const idle = 150 * time.Millisecond
start := time.Now()
done := make(chan struct{})
go func() {
Relay(ctx, relayA, relayB, Options{IdleTimeout: 150 * time.Millisecond})
Relay(ctx, relayA, relayB, Options{IdleTimeout: idle})
close(done)
}()
@@ -199,5 +213,9 @@ func TestRelayIdleTimeout(t *testing.T) {
t.Fatal("relay did not close on idle")
}
require.WithinDuration(t, start.Add(150*time.Millisecond), time.Now(), 500*time.Millisecond)
elapsed := time.Since(start)
require.GreaterOrEqual(t, elapsed, idle,
"relay must not close before the idle timeout elapses")
require.Less(t, elapsed, idle+500*time.Millisecond,
"relay should close shortly after the idle timeout")
}