From 10da236dae7420a75bbdee5f79ac63c1d2d9c7c8 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Tue, 21 Apr 2026 13:01:50 +0200 Subject: [PATCH] Address PR review: connection-wide idle watchdog, test hardening - netrelay: replace per-direction read-deadline idle tracking with a single connection-wide watchdog that observes activity on both sides, so a long one-way transfer no longer trips the timeout on the quiet direction. IdleTimeout==0 remains a no-op (SSH and uspfilter forwarder call sites pass zero); only the reverse-proxy router sets one. - netrelay tests: bound blocking peer reads/writes with deadlines so a broken relay fails fast; add a lower-bound assertion on the idle-timeout test. - conntrack cap tests: assert that the newest flow is admitted and an early flow was evicted, not just that the table stayed under the cap. - ssh client RemotePortForward: bound the localAddr dial with a 10s timeout so a black-holed address can't pin the accepted channel open. --- .../firewall/uspfilter/conntrack/cap_test.go | 24 +++++ client/ssh/client/client.go | 7 +- util/netrelay/relay.go | 88 ++++++++++--------- util/netrelay/relay_test.go | 22 ++++- 4 files changed, 97 insertions(+), 44 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/cap_test.go b/client/firewall/uspfilter/conntrack/cap_test.go index 1f633f134..7b7f814f1 100644 --- a/client/firewall/uspfilter/conntrack/cap_test.go +++ b/client/firewall/uspfilter/conntrack/cap_test.go @@ -25,6 +25,16 @@ func TestTCPCapEvicts(t *testing.T) { "TCP table must not exceed the configured cap") require.Greater(t, len(tracker.connections), 0, "some entries must remain after eviction") + + // The most recently admitted flow must be present: eviction must make + // room for new entries, not silently drop them. + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80}, + "newest TCP flow must be admitted after eviction") + // A pre-cap flow must have been evicted to fit the last one. + require.NotContains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80}, + "oldest TCP flow should have been evicted") } func TestTCPCapPrefersTombstonedForEviction(t *testing.T) { @@ -71,6 +81,13 @@ func TestUDPCapEvicts(t *testing.T) { } require.LessOrEqual(t, len(tracker.connections), 5) require.Greater(t, len(tracker.connections), 0) + + require.Contains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53}, + "newest UDP flow must be admitted after eviction") + require.NotContains(t, tracker.connections, + ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53}, + "oldest UDP flow should have been evicted") } func TestICMPCapEvicts(t *testing.T) { @@ -89,4 +106,11 @@ func TestICMPCapEvicts(t *testing.T) { } require.LessOrEqual(t, len(tracker.connections), 3) require.Greater(t, len(tracker.connections), 0) + + require.Contains(t, tracker.connections, + ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)}, + "newest ICMP flow must be admitted after eviction") + require.NotContains(t, tracker.connections, + ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)}, + "oldest ICMP flow should have been evicted") } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 61904366d..ebf8eb794 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -679,9 +679,14 @@ func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.New go ssh.DiscardRequests(reqs) + // Bound the dial so a black-holed localAddr can't pin the accepted SSH + // channel open indefinitely; the relay itself runs under the outer ctx. + dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second) var dialer net.Dialer - localConn, err := dialer.DialContext(ctx, "tcp", localAddr) + localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr) + cancelDial() if err != nil { + log.Debugf("remote port forwarding: dial %s: %v", localAddr, err) return } defer func() { diff --git a/util/netrelay/relay.go b/util/netrelay/relay.go index 3afd35b1b..662ce3e1d 100644 --- a/util/netrelay/relay.go +++ b/util/netrelay/relay.go @@ -14,6 +14,7 @@ import ( "io" "net" "sync" + "sync/atomic" "syscall" "time" ) @@ -31,10 +32,6 @@ type DebugLogger interface { // that want an idle timeout but have no specific preference can use this. const DefaultIdleTimeout = 5 * time.Minute -// ErrIdleTimeout is returned when a relay direction is torn down after no -// data flowed for longer than Options.IdleTimeout. -var ErrIdleTimeout = errors.New("idle timeout") - // halfCloser is implemented by connections that support half-close // (e.g. *net.TCPConn, *gonet.TCPConn). type halfCloser interface { @@ -51,8 +48,10 @@ var copyBufPool = sync.Pool{ // Options configures Relay behavior. The zero value is valid: no idle timeout, // no logging. type Options struct { - // IdleTimeout tears down a direction if no bytes flow within this - // window. Zero disables idle tracking. + // IdleTimeout tears down the session if no bytes flow in either + // direction within this window. It is a connection-wide watchdog, so a + // long unidirectional transfer on one side keeps the other side alive. + // Zero disables idle tracking. IdleTimeout time.Duration // Logger receives debug-level copy/idle errors. Nil suppresses logging. // Any logger with Debug/Debugf methods is accepted (logrus.Entry, @@ -67,8 +66,8 @@ type Options struct { // // a and b only need to implement io.ReadWriteCloser; connections that also // implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close -// propagation. Idle-timeout enforcement requires a net.Conn; for other types -// Options.IdleTimeout is ignored. +// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide +// watchdog that tracks reads in either direction. // // Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read → // a.Write). Errors are logged via Options.Logger when set; they are not @@ -91,6 +90,16 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT _, bHC := b.(halfCloser) halfCloseSupported := aHC && bHC + var ( + lastActivity atomic.Int64 + idleHit atomic.Bool + ) + lastActivity.Store(time.Now().UnixNano()) + + if opts.IdleTimeout > 0 { + go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout) + } + var wg sync.WaitGroup wg.Add(2) @@ -98,7 +107,7 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT go func() { defer wg.Done() - aToB, errAToB = copyWithIdleTimeout(b, a, opts.IdleTimeout) + aToB, errAToB = copyTracked(b, a, &lastActivity) if halfCloseSupported { halfClose(b) } else { @@ -108,7 +117,7 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT go func() { defer wg.Done() - bToA, errBToA = copyWithIdleTimeout(a, b, opts.IdleTimeout) + bToA, errBToA = copyTracked(a, b, &lastActivity) if halfCloseSupported { halfClose(a) } else { @@ -119,7 +128,7 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT wg.Wait() if opts.Logger != nil { - if errors.Is(errAToB, ErrIdleTimeout) || errors.Is(errBToA, ErrIdleTimeout) { + if idleHit.Load() { opts.Logger.Debugf("relay closed due to idle timeout") } if errAToB != nil && !isExpectedCopyError(errAToB) { @@ -133,30 +142,41 @@ func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bT return aToB, bToA } -// copyWithIdleTimeout copies from src to dst using a pooled buffer. When -// idleTimeout > 0 and src is a net.Conn it sets a read deadline before each -// read; a timeout is reported as ErrIdleTimeout. -func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { +// watchdog enforces a connection-wide idle timeout. It cancels ctx when no +// activity has been seen on either direction for idle. It exits as soon as +// ctx is canceled so it doesn't outlive the relay. +func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) { + tick := max(idle/2, 50*time.Millisecond) + t := time.NewTicker(tick) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + last := time.Unix(0, lastActivity.Load()) + if time.Since(last) >= idle { + idleHit.Store(true) + cancel() + return + } + } + } +} + +// copyTracked copies from src to dst using a pooled buffer, updating +// lastActivity on every successful read so a shared watchdog can enforce a +// connection-wide idle timeout. +func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) { bufp := copyBufPool.Get().(*[]byte) defer copyBufPool.Put(bufp) - if idleTimeout <= 0 { - return io.CopyBuffer(dst, src, *bufp) - } - - conn, ok := src.(net.Conn) - if !ok { - return io.CopyBuffer(dst, src, *bufp) - } - buf := *bufp var total int64 for { - if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { - return total, err - } nr, readErr := src.Read(buf) if nr > 0 { + lastActivity.Store(time.Now().UnixNano()) n, werr := checkedWrite(dst, buf[:nr]) total += n if werr != nil { @@ -164,9 +184,6 @@ func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration } } if readErr != nil { - if isNetTimeout(readErr) { - return total, ErrIdleTimeout - } return total, readErr } } @@ -192,18 +209,7 @@ func halfClose(conn io.ReadWriteCloser) { } } -func isNetTimeout(err error) bool { - var netErr net.Error - if errors.As(err, &netErr) { - return netErr.Timeout() - } - return false -} - func isExpectedCopyError(err error) bool { - if errors.Is(err, ErrIdleTimeout) { - return true - } return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF) || diff --git a/util/netrelay/relay_test.go b/util/netrelay/relay_test.go index 26baebfbb..0cb86eb0d 100644 --- a/util/netrelay/relay_test.go +++ b/util/netrelay/relay_test.go @@ -51,6 +51,12 @@ func TestRelayHalfClose(t *testing.T) { defer peerA.Close() defer peerB.Close() + // Bound blocking reads/writes so a broken relay fails the test instead of + // hanging the test process. + deadline := time.Now().Add(5 * time.Second) + require.NoError(t, peerA.SetDeadline(deadline)) + require.NoError(t, peerB.SetDeadline(deadline)) + ctx := t.Context() done := make(chan struct{}) @@ -98,6 +104,12 @@ func TestRelayFullDuplex(t *testing.T) { defer peerA.Close() defer peerB.Close() + // Bound blocking reads/writes so a broken relay fails the test instead of + // hanging the test process. + deadline := time.Now().Add(5 * time.Second) + require.NoError(t, peerA.SetDeadline(deadline)) + require.NoError(t, peerB.SetDeadline(deadline)) + ctx := t.Context() done := make(chan struct{}) @@ -186,10 +198,12 @@ func TestRelayIdleTimeout(t *testing.T) { ctx := t.Context() + const idle = 150 * time.Millisecond + start := time.Now() done := make(chan struct{}) go func() { - Relay(ctx, relayA, relayB, Options{IdleTimeout: 150 * time.Millisecond}) + Relay(ctx, relayA, relayB, Options{IdleTimeout: idle}) close(done) }() @@ -199,5 +213,9 @@ func TestRelayIdleTimeout(t *testing.T) { t.Fatal("relay did not close on idle") } - require.WithinDuration(t, start.Add(150*time.Millisecond), time.Now(), 500*time.Millisecond) + elapsed := time.Since(start) + require.GreaterOrEqual(t, elapsed, idle, + "relay must not close before the idle timeout elapses") + require.Less(t, elapsed, idle+500*time.Millisecond, + "relay should close shortly after the idle timeout") }