Compare commits

...

5 Commits

Author SHA1 Message Date
Viktor Liu
5506507313 netrelay: wait for endpoint close before Relay returns
The closer goroutine ran asynchronously on ctx cancellation, so the
"fully closed when Relay returns" guarantee was racy: callers could see
the function return before a and b were actually Close()d. Wait on a
done channel in the defer so the guarantee holds.
2026-04-21 15:50:24 +02:00
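A minimal sketch of the pattern this message describes (hypothetical relay/a/b names; the actual netrelay API differs):
package relaysketch // hypothetical package, illustrative only

import (
    "context"
    "net"
)

// relay returns only after both conns are closed: the defer waits on done,
// which the closer goroutine closes once a and b have been Close()d.
func relay(ctx context.Context, a, b net.Conn) {
    ctx, cancel := context.WithCancel(ctx)
    done := make(chan struct{})
    go func() {
        defer close(done)
        <-ctx.Done()
        _ = a.Close()
        _ = b.Close()
    }()
    defer func() {
        cancel() // wake the closer if nothing triggered it already
        <-done   // block until a and b are really closed
    }()
    // ... bidirectional copy elided ...
}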
Viktor Liu
1311fa2aad netrelay: tighten watchdog tick for short idle timeouts
Use min(idle/2, 50ms) so very short idle timeouts (mainly in tests) are
caught within a single tick, while the 50ms cap keeps detection latency
bounded for long idle values without causing needlessly frequent wakeups.
2026-04-21 14:54:07 +02:00
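The computation restated as code (watchdogTick is a hypothetical name, not necessarily the helper in the patch):
import "time"

// watchdogTick: a 20ms test timeout polls every 10ms and is caught within
// one tick; a 5m production timeout polls every 50ms, bounding latency.
func watchdogTick(idle time.Duration) time.Duration {
    return min(idle/2, 50*time.Millisecond)
}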
Viktor Liu
be434e1eb2 Address PR review: cancel on non-EOF copy errors, stricter cap test
- netrelay: only propagate CloseWrite on clean io.EOF; cancel both sides
  on any other copy error so a short write, reset, or broken pipe can't
  leave the opposite direction blocked.
- TestTCPCapPrefersTombstonedForEviction: assert both live pre-cap
  entries survive, not just that the tombstone is gone, so a regression
  that evicts a live entry instead of the tombstone is caught.
2026-04-21 14:15:04 +02:00
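A sketch of one relay direction under this rule (copyOneWay is an invented name; assumes *net.TCPConn so CloseWrite is available):
import (
    "context"
    "io"
    "net"
)

// io.Copy reports a clean EOF from src as a nil error; anything else
// (short write, reset, broken pipe) cancels both directions.
func copyOneWay(cancel context.CancelFunc, dst, src *net.TCPConn) {
    if _, err := io.Copy(dst, src); err != nil {
        cancel() // hard failure: don't leave the opposite direction blocked
        return
    }
    _ = dst.CloseWrite() // clean EOF: propagate half-close, reverse path stays open
}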
Viktor Liu
10da236dae Address PR review: connection-wide idle watchdog, test hardening
- netrelay: replace per-direction read-deadline idle tracking with a
  single connection-wide watchdog that observes activity on both sides,
  so a long one-way transfer no longer trips the timeout on the quiet
  direction. IdleTimeout==0 remains a no-op (SSH and uspfilter forwarder
  call sites pass zero); only the reverse-proxy router sets one.
- netrelay tests: bound blocking peer reads/writes with deadlines so a
  broken relay fails fast; add a lower-bound assertion on the idle-timeout
  test.
- conntrack cap tests: assert that the newest flow is admitted and an
  early flow was evicted, not just that the table stayed under the cap.
- ssh client RemotePortForward: bound the localAddr dial with a 10s
  timeout so a black-holed address can't pin the accepted channel open.
2026-04-21 13:01:50 +02:00
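A sketch of a connection-wide idle watchdog as described in the first bullet (activity/watchIdle are hypothetical; the real netrelay code differs):
package relaysketch // hypothetical package, illustrative only

import (
    "context"
    "sync/atomic"
    "time"
)

// activity is touched by both copy directions, so one busy direction is
// enough to keep the whole connection alive.
type activity struct{ lastNano atomic.Int64 }

func (a *activity) touch() { a.lastNano.Store(time.Now().UnixNano()) }

func watchIdle(ctx context.Context, cancel context.CancelFunc, act *activity, idle time.Duration) {
    ticker := time.NewTicker(min(idle/2, 50*time.Millisecond))
    defer ticker.Stop()
    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            if time.Duration(time.Now().UnixNano()-act.lastNano.Load()) > idle {
                cancel() // idle in both directions: tear down both copy loops
                return
            }
        }
    }
}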
Viktor Liu
ffac18409e Harden uspfilter conntrack and share half-close-correct TCP relay 2026-04-21 10:47:23 +02:00
26 changed files with 1624 additions and 488 deletions

View File

@@ -0,0 +1,125 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/google/gopacket/layers"
"github.com/stretchr/testify/require"
)
func TestTCPCapEvicts(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "4")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 4, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 10; i++ {
tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0)
}
require.LessOrEqual(t, len(tracker.connections), 4,
"TCP table must not exceed the configured cap")
require.Greater(t, len(tracker.connections), 0,
"some entries must remain after eviction")
// The most recently admitted flow must be present: eviction must make
// room for new entries, not silently drop them.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10009), DstPort: 80},
"newest TCP flow must be admitted after eviction")
// A pre-cap flow must have been evicted to fit the last one.
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(10000), DstPort: 80},
"oldest TCP flow should have been evicted")
}
func TestTCPCapPrefersTombstonedForEviction(t *testing.T) {
t.Setenv(EnvTCPMaxEntries, "3")
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
// Fill to cap with 3 live connections.
for i := 0; i < 3; i++ {
tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0)
}
require.Len(t, tracker.connections, 3)
// Tombstone one by sending RST through IsValidInbound.
tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80}
require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0))
require.True(t, tracker.connections[tombstonedKey].IsTombstone())
// Another live connection forces eviction. The tombstone must go first.
tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0)
_, tombstonedStillPresent := tracker.connections[tombstonedKey]
require.False(t, tombstonedStillPresent,
"tombstoned entry should be evicted before live entries")
require.LessOrEqual(t, len(tracker.connections), 3)
// Both live pre-cap entries must survive: eviction must prefer the
// tombstone, not just satisfy the size bound by dropping any entry.
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20000), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(20002), DstPort: 80},
"live entries must not be evicted while a tombstone exists")
}
func TestUDPCapEvicts(t *testing.T) {
t.Setenv(EnvUDPMaxEntries, "5")
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 5, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
for i := 0; i < 12; i++ {
tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0)
}
require.LessOrEqual(t, len(tracker.connections), 5)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30011), DstPort: 53},
"newest UDP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ConnKey{SrcIP: src, DstIP: dst, SrcPort: uint16(30000), DstPort: 53},
"oldest UDP flow should have been evicted")
}
func TestICMPCapEvicts(t *testing.T) {
t.Setenv(EnvICMPMaxEntries, "3")
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 3, tracker.maxEntries)
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0)
for i := 0; i < 8; i++ {
tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64)
}
require.LessOrEqual(t, len(tracker.connections), 3)
require.Greater(t, len(tracker.connections), 0)
require.Contains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(7)},
"newest ICMP flow must be admitted after eviction")
require.NotContains(t, tracker.connections,
ICMPConnKey{SrcIP: src, DstIP: dst, ID: uint16(0)},
"oldest ICMP flow should have been evicted")
}

View File

@@ -3,14 +3,61 @@ package conntrack
import (
"fmt"
"net/netip"
+"os"
+"strconv"
"sync/atomic"
"time"
"github.com/google/uuid"
+nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
+// evictSampleSize bounds how many map entries we scan per eviction call.
+// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU
+// heuristic is good enough for a conntrack table that only overflows under
+// abuse.
+const evictSampleSize = 8
+// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to
+// def on empty or invalid; logs a warning on invalid.
+func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration {
+v := os.Getenv(name)
+if v == "" {
+return def
+}
+d, err := time.ParseDuration(v)
+if err != nil {
+logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
+return def
+}
+if d <= 0 {
+logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
+return def
+}
+return d
+}
+// envInt parses an os.Getenv(name) as an int. Falls back to def on empty,
+// invalid, or non-positive. Logs a warning on invalid input.
+func envInt(logger *nblog.Logger, name string, def int) int {
+v := os.Getenv(name)
+if v == "" {
+return def
+}
+n, err := strconv.Atoi(v)
+switch {
+case err != nil:
+logger.Warn3("invalid %s=%q: %v, using default", name, v, err)
+return def
+case n <= 0:
+logger.Warn2("invalid %s=%q: must be positive, using default", name, v)
+return def
+}
+return n
+}
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
FlowId uuid.UUID
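Editorial note on the sampled eviction above (a rough check, not from the patch): with 8 roughly uniform samples, the chance that at least one lands in the oldest quartile of the table is about 1 - 0.75^8 ≈ 0.90, so the bounded scan almost always reclaims a genuinely old entry while staying O(1).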

View File

@@ -0,0 +1,11 @@
//go:build !ios && !android
package conntrack
// Default per-tracker entry caps on desktop/server platforms. These mirror
// typical Linux netfilter nf_conntrack_max territory with ample headroom.
const (
DefaultMaxTCPEntries = 65536
DefaultMaxUDPEntries = 16384
DefaultMaxICMPEntries = 2048
)
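A usage sketch for these caps (the test name is invented; the env-var constants, t.Setenv pattern, and logger/flowLogger globals come from the tests in this change):
package conntrack

import "testing"

// TestCustomCaps: trackers constructed after t.Setenv observe the overridden
// caps, and fall back to the Default* values above when unset or invalid.
func TestCustomCaps(t *testing.T) {
    t.Setenv(EnvTCPMaxEntries, "8192")
    t.Setenv(EnvUDPMaxEntries, "4096")
    t.Setenv(EnvICMPMaxEntries, "256")

    tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
    defer tracker.Close()
    if tracker.maxEntries != 8192 {
        t.Fatalf("expected cap 8192, got %d", tracker.maxEntries)
    }
}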

View File

@@ -0,0 +1,13 @@
//go:build ios || android
package conntrack
// Default per-tracker entry caps on mobile platforms. iOS network extensions
// are capped at ~50 MB; Android runs under aggressive memory pressure. These
// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack
// is ~200 B plus map overhead).
const (
DefaultMaxTCPEntries = 4096
DefaultMaxUDPEntries = 2048
DefaultMaxICMPEntries = 512
)

View File

@@ -44,6 +44,9 @@ type ICMPConnTrack struct {
ICMPCode uint8
}
+// EnvICMPMaxEntries caps the ICMP conntrack table size.
+const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX"
// ICMPTracker manages ICMP connection states
type ICMPTracker struct {
logger *nblog.Logger
@@ -52,6 +55,7 @@ type ICMPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
+maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -135,6 +139,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty
timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
tickerCancel: cancel,
+maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries),
flowLogger: flowLogger,
}
@@ -221,7 +226,9 @@ func (t *ICMPTracker) track(
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
+if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
+}
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
@@ -240,10 +247,15 @@ func (t *ICMPTracker) track(
conn.UpdateCounters(direction, size)
t.mutex.Lock()
+if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
+t.evictOneLocked()
+}
t.connections[key] = conn
t.mutex.Unlock()
+if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo)
+}
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
@@ -286,6 +298,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
}
}
+// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
+// Bounded sample scan: picks the oldest among up to evictSampleSize entries.
+func (t *ICMPTracker) evictOneLocked() {
+var candKey ICMPConnKey
+var candSeen int64
+haveCand := false
+sampled := 0
+for k, c := range t.connections {
+seen := c.lastSeen.Load()
+if !haveCand || seen < candSeen {
+candKey = k
+candSeen = seen
+haveCand = true
+}
+sampled++
+if sampled >= evictSampleSize {
+break
+}
+}
+if haveCand {
+if evicted := t.connections[candKey]; evicted != nil {
+t.sendEvent(nftypes.TypeEnd, evicted, nil)
+}
+delete(t.connections, candKey)
+}
+}
func (t *ICMPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()
@@ -294,8 +334,10 @@ func (t *ICMPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
+if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
+}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -38,6 +38,27 @@ const (
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
+// FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states.
+// Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait.
+FinWaitTimeout = 60 * time.Second
+// CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps
+// holding CloseWait longer than this should bump the env var.
+CloseWaitTimeout = 60 * time.Second
+// LastAckTimeout bounds LAST_ACK. Matches Linux default.
+LastAckTimeout = 30 * time.Second
+)
+// Env vars to override per-state teardown timeouts. Values parsed by
+// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the
+// defaults above with a warning.
+const (
+EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT"
+EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT"
+EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT"
+// EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries
+// (tombstones first) are evicted when the cap is reached.
+EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX"
)
// TCPState represents the state of a TCP connection
@@ -140,6 +161,10 @@ type TCPTracker struct {
tickerCancel context.CancelFunc
timeout time.Duration
waitTimeout time.Duration
+finWaitTimeout time.Duration
+closeWaitTimeout time.Duration
+lastAckTimeout time.Duration
+maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -161,6 +186,10 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
tickerCancel: cancel,
timeout: timeout,
waitTimeout: waitTimeout,
+finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout),
+closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout),
+lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout),
+maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries),
flowLogger: flowLogger,
}
@@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
if exists || flags&TCPSyn == 0 {
return
}
+// Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't
+// create spurious conntrack entries. Not mandated by RFC 9293 but a
+// common hardening (Linux netfilter/nftables rejects these too).
+if !isValidFlagCombination(flags) {
+return
+}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
@@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla
conn.state.Store(int32(TCPStateNew))
conn.DNATOrigPort.Store(uint32(origPort))
+if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s TCP connection: %s", direction, key)
}
+}
t.updateState(key, conn, flags, direction, size)
t.mutex.Lock()
+if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
+t.evictOneLocked()
+}
t.connections[key] = conn
t.mutex.Unlock()
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
+// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
+// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map
+// iteration order is randomized), preferring a tombstone. If no tombstone
+// found in the sample, evicts the oldest among the sampled entries. O(1)
+// worst case — cheap enough to run on every insert at cap during abuse.
+func (t *TCPTracker) evictOneLocked() {
+var candKey ConnKey
+var candSeen int64
+haveCand := false
+sampled := 0
+for k, c := range t.connections {
+if c.IsTombstone() {
+delete(t.connections, k)
+return
+}
+seen := c.lastSeen.Load()
+if !haveCand || seen < candSeen {
+candKey = k
+candSeen = seen
+haveCand = true
+}
+sampled++
+if sampled >= evictSampleSize {
+break
+}
+}
+if haveCand {
+if evicted := t.connections[candKey]; evicted != nil {
+// TypeEnd is already emitted at the state transition to
+// TimeWait and when a connection is tombstoned. Only emit
+// here when we're reaping a still-active flow.
+if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() {
+t.sendEvent(nftypes.TypeEnd, evicted, nil)
+}
+}
+delete(t.connections, candKey)
+}
+}
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
key := ConnKey{
@@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return false
}
+// Reject illegal flag combinations regardless of state. These never belong
+// to a legitimate flow and must not advance or tear down state.
+if !isValidFlagCombination(flags) {
+if t.logger.Enabled(nblog.LevelWarn) {
+t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState())
+}
+return false
+}
currentState := conn.GetState()
if !t.isValidStateForFlags(currentState, flags) {
+if t.logger.Enabled(nblog.LevelWarn) {
t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
-// allow all flags for established for now
-if currentState == TCPStateEstablished {
-return true
}
return false
}
@@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
return true
}
-// updateState updates the TCP connection state based on flags
+// updateState updates the TCP connection state based on flags.
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
-conn.UpdateLastSeen()
conn.UpdateCounters(packetDir, size)
+// Malformed flag combinations must not refresh lastSeen or drive state,
+// otherwise spoofed packets keep a dead flow alive past its timeout.
+if !isValidFlagCombination(flags) {
+return
+}
+conn.UpdateLastSeen()
currentState := conn.GetState()
if flags&TCPRst != 0 {
-if conn.CompareAndSwapState(currentState, TCPStateClosed) {
-conn.SetTombstone()
-t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
-key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
-t.sendEvent(nftypes.TypeEnd, conn, nil)
-}
+// Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we
+// cannot apply the RFC 5961 in-window RST check, so we conservatively
+// reject RSTs that the spec would accept (TIME-WAIT with in-window
+// SEQ, SynSent from same direction as own SYN, etc.).
+t.handleRst(key, conn, currentState, packetDir)
return
}
-var newState TCPState
-switch currentState {
-case TCPStateNew:
+newState := nextState(currentState, conn.Direction, packetDir, flags)
+if newState == 0 || !conn.CompareAndSwapState(currentState, newState) {
+return
+}
+t.onTransition(key, conn, currentState, newState, packetDir)
+}
+// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs
+// from the SYN direction are ignored; otherwise the flow is tombstoned.
+func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) {
+// TimeWait exists to absorb late segments; don't let a late RST
+// tombstone the entry and break same-4-tuple reuse.
+if currentState == TCPStateTimeWait {
+return
+}
+// A RST from the same direction as the SYN cannot be a legitimate
+// response and must not tear down a half-open connection.
+if currentState == TCPStateSynSent && packetDir == conn.Direction {
+return
+}
+if !conn.CompareAndSwapState(currentState, TCPStateClosed) {
+return
+}
+conn.SetTombstone()
+if t.logger.Enabled(nblog.LevelTrace) {
+t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
+key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
+}
+t.sendEvent(nftypes.TypeEnd, conn, nil)
+}
+// stateTransition describes one state's transition logic. It receives the
+// packet's flags plus whether the packet direction matches the connection's
+// origin direction (same=true means same side as the SYN initiator). Return 0
+// for no transition.
+type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState
+// stateTable maps each state to its transition function. Centralized here so
+// nextState stays trivial and each rule is easy to read in isolation.
+var stateTable = map[TCPState]stateTransition{
+TCPStateNew: transNew,
+TCPStateSynSent: transSynSent,
+TCPStateSynReceived: transSynReceived,
+TCPStateEstablished: transEstablished,
+TCPStateFinWait1: transFinWait1,
+TCPStateFinWait2: transFinWait2,
+TCPStateClosing: transClosing,
+TCPStateCloseWait: transCloseWait,
+TCPStateLastAck: transLastAck,
+}
+// nextState returns the target TCP state for the given current state and
+// packet, or 0 if the packet does not trigger a transition.
+func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState {
+fn, ok := stateTable[currentState]
+if !ok {
+return 0
+}
+return fn(flags, connDir, packetDir == connDir)
+}
+func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
-if conn.Direction == nftypes.Egress {
-newState = TCPStateSynSent
-} else {
-newState = TCPStateSynReceived
-}
+if connDir == nftypes.Egress {
+return TCPStateSynSent
+}
+return TCPStateSynReceived
}
+return 0
+}
-case TCPStateSynSent:
+func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState {
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
-if packetDir != conn.Direction {
-newState = TCPStateEstablished
-} else {
-// Simultaneous open
-newState = TCPStateSynReceived
-}
+if same {
+return TCPStateSynReceived // simultaneous open
+}
+return TCPStateEstablished
}
+return 0
+}
-case TCPStateSynReceived:
-if flags&TCPAck != 0 && flags&TCPSyn == 0 {
-if packetDir == conn.Direction {
-newState = TCPStateEstablished
-}
+func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if flags&TCPAck != 0 && flags&TCPSyn == 0 && same {
+return TCPStateEstablished
}
+return 0
+}
-case TCPStateEstablished:
-if flags&TCPFin != 0 {
-if packetDir == conn.Direction {
-newState = TCPStateFinWait1
-} else {
-newState = TCPStateCloseWait
-}
-}
+func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if flags&TCPFin == 0 {
+return 0
+}
+if same {
+return TCPStateFinWait1
+}
+return TCPStateCloseWait
+}
-case TCPStateFinWait1:
-if packetDir != conn.Direction {
+// transFinWait1 handles the active-close peer response. A FIN carrying our
+// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1:
+// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves
+// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2.
+func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if same {
+return 0
+}
+if flags&TCPFin != 0 && flags&TCPAck != 0 {
+return TCPStateTimeWait
+}
switch {
-case flags&TCPFin != 0 && flags&TCPAck != 0:
-newState = TCPStateClosing
case flags&TCPFin != 0:
-newState = TCPStateClosing
+return TCPStateClosing
case flags&TCPAck != 0:
-newState = TCPStateFinWait2
+return TCPStateFinWait2
}
-}
+return 0
+}
-case TCPStateFinWait2:
-if flags&TCPFin != 0 {
-newState = TCPStateTimeWait
-}
+// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances.
+func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if flags&TCPFin != 0 && !same {
+return TCPStateTimeWait
+}
+return 0
+}
-case TCPStateClosing:
-if flags&TCPAck != 0 {
-newState = TCPStateTimeWait
-}
+// transClosing completes a simultaneous close on the peer's ACK.
+func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if flags&TCPAck != 0 && !same {
+return TCPStateTimeWait
+}
+return 0
+}
-case TCPStateCloseWait:
-if flags&TCPFin != 0 {
-newState = TCPStateLastAck
-}
+// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits.
+func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if flags&TCPFin != 0 && same {
+return TCPStateLastAck
+}
+return 0
+}
-case TCPStateLastAck:
-if flags&TCPAck != 0 {
-newState = TCPStateClosed
-}
-}
+// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits).
+func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState {
+if flags&TCPAck != 0 && !same {
+return TCPStateClosed
+}
+return 0
+}
-if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
-t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
-switch newState {
+// onTransition handles logging and flow-event emission after a successful
+// state transition. TimeWait and Closed are terminal for flow accounting.
+func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) {
+traceOn := t.logger.Enabled(nblog.LevelTrace)
+if traceOn {
+t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir)
+}
+switch to {
case TCPStateTimeWait:
+if traceOn {
t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
+}
t.sendEvent(nftypes.TypeEnd, conn, nil)
case TCPStateClosed:
conn.SetTombstone()
+if traceOn {
t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
-t.sendEvent(nftypes.TypeEnd, conn, nil)
+}
+t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}
-}
-// isValidStateForFlags checks if the TCP flags are valid for the current connection state
+// isValidStateForFlags checks if the TCP flags are valid for the current
+// connection state. Caller must have already verified the flag combination is
+// legal via isValidFlagCombination.
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
-if !isValidFlagCombination(flags) {
-return false
-}
if flags&TCPRst != 0 {
if state == TCPStateSynSent {
return flags&TCPAck != 0
@@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() {
timeout = t.waitTimeout
case TCPStateEstablished:
timeout = t.timeout
+case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing:
+timeout = t.finWaitTimeout
+case TCPStateCloseWait:
+timeout = t.closeWaitTimeout
+case TCPStateLastAck:
+timeout = t.lastAckTimeout
default:
+// SynSent / SynReceived / New
timeout = TCPHandshakeTimeout
}
if conn.timeoutExceeded(timeout) {
delete(t.connections, key)
+if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
+}
// event already handled by state change
if currentState != TCPStateTimeWait {
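To make the transition table concrete, a sketch of an active close on an egress flow, Established → FinWait1 → FinWait2 → TimeWait (assumes it compiles inside package conntrack next to the code above, plus an fmt import; the Example name is invented):
// Sketch: drive nextState through an active close (egress side sent the SYN).
func ExampleNextStateActiveClose() {
    egress, ingress := nftypes.Egress, nftypes.Ingress
    s := TCPStateEstablished
    s = nextState(s, egress, egress, TCPFin|TCPAck)  // we send FIN    -> FinWait1
    s = nextState(s, egress, ingress, TCPAck)        // peer ACKs FIN  -> FinWait2
    s = nextState(s, egress, ingress, TCPFin|TCPAck) // peer's own FIN -> TimeWait
    fmt.Println(s == TCPStateTimeWait)
    // Output: true
}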

View File

@@ -0,0 +1,100 @@
package conntrack
import (
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
// RST hygiene tests: the tracker currently closes the flow on any RST that
// matches the 4-tuple, regardless of direction or state. These tests cover
// the minimum checks we want (no SEQ tracking).
func TestTCPRstInSynSentWrongDirection(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateSynSent, conn.GetState())
// A RST arriving in the same direction as the SYN (i.e. TrackOutbound)
// cannot be a legitimate response. It must not close the connection.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0)
require.Equal(t, TCPStateSynSent, conn.GetState(),
"RST in same direction as SYN must not close connection")
require.False(t, conn.IsTombstone())
}
func TestTCPRstInTimeWaitIgnored(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
// Drive to TIME-WAIT via active close.
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateTimeWait, conn.GetState())
require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned")
// Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT
// exists to absorb late segments).
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.Equal(t, TCPStateTimeWait, conn.GetState(),
"RST in TIME-WAIT must not transition state")
require.False(t, conn.IsTombstone(),
"RST in TIME-WAIT must not tombstone the entry")
}
func TestTCPIllegalFlagCombos(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
conn := tracker.connections[key]
// Illegal combos must be rejected and must not change state.
combos := []struct {
name string
flags uint8
}{
{"SYN+RST", TCPSyn | TCPRst},
{"FIN+RST", TCPFin | TCPRst},
{"SYN+FIN", TCPSyn | TCPFin},
{"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst},
}
for _, c := range combos {
t.Run(c.name, func(t *testing.T) {
before := conn.GetState()
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0)
require.False(t, valid, "illegal flag combo must be rejected: %s", c.name)
require.Equal(t, before, conn.GetState(),
"illegal flag combo must not change state")
require.False(t, conn.IsTombstone())
})
}
}

View File

@@ -0,0 +1,235 @@
package conntrack
import (
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// These tests exercise cases where the TCP state machine currently advances
// on retransmitted or wrong-direction segments and tears the flow down
// prematurely. They are expected to fail until the direction checks are added.
func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer sends FIN -> CloseWait (our app has not yet closed).
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateCloseWait, conn.GetState())
// Peer retransmits their FIN (ACK may have been delayed). We have NOT
// sent our FIN yet, so state must remain CloseWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "retransmitted peer FIN must still be accepted")
require.Equal(t, TCPStateCloseWait, conn.GetState(),
"retransmitted peer FIN must not advance CloseWait to LastAck")
// Our app finally closes -> LastAck.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState())
// Peer ACK closes.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// We initiate close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait2, conn.GetState())
// Stray retransmit of our own FIN (same direction as originator) must
// NOT advance FinWait2 to TimeWait; only the peer's FIN should.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
require.Equal(t, TCPStateFinWait2, conn.GetState(),
"own FIN retransmit must not advance FinWait2 to TimeWait")
// Peer FIN -> TimeWait.
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid)
require.Equal(t, TCPStateTimeWait, conn.GetState())
}
func TestTCPLastAckDirectionCheck(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateLastAck, conn.GetState())
// Our own ACK retransmit (same direction as originator) must NOT close.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateLastAck, conn.GetState(),
"own ACK retransmit in LastAck must not transition to Closed")
// Peer's ACK -> Closed.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0))
require.Equal(t, TCPStateClosed, conn.GetState())
}
func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
conn := tracker.connections[key]
require.Equal(t, TCPStateFinWait1, conn.GetState())
// Our own ACK retransmit (same direction as originator) must not advance.
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
require.Equal(t, TCPStateFinWait1, conn.GetState(),
"own ACK in FinWait1 must not advance to FinWait2")
}
func TestTCPPerStateTeardownTimeouts(t *testing.T) {
// Verify cleanup reaps entries in each teardown state at the configured
// per-state timeout, not at the single handshake timeout.
t.Setenv(EnvTCPFinWaitTimeout, "50ms")
t.Setenv(EnvTCPCloseWaitTimeout, "80ms")
t.Setenv(EnvTCPLastAckTimeout, "30ms")
dstIP := netip.MustParseAddr("100.64.0.2")
dstPort := uint16(80)
// Drives a connection to the target state, forces its lastSeen well
// beyond the configured timeout, runs cleanup, and asserts reaping.
cases := []struct {
name string
// drive takes a fresh tracker and returns the conn key after
// transitioning the flow into the intended teardown state.
drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState)
}{
{
name: "FinWait1",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1
},
},
{
name: "FinWait2",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2
},
},
{
name: "CloseWait",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait
},
},
{
name: "LastAck",
drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) {
establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort)
require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait
tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck
return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck
},
},
}
// Use a unique source port per subtest so nothing aliases.
port := uint16(12345)
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout)
require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout)
require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout)
srcIP := netip.MustParseAddr("100.64.0.1")
port++
key, wantState := c.drive(t, tracker, srcIP, port)
conn := tracker.connections[key]
require.NotNil(t, conn)
require.Equal(t, wantState, conn.GetState())
// Age the entry past the largest per-state timeout.
conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano())
tracker.cleanup()
_, exists := tracker.connections[key]
require.False(t, exists, "%s entry should be reaped", c.name)
})
}
}
func TestTCPEstablishedPSHACKInFinStates(t *testing.T) {
// Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN
// teardown states, which some stacks emit during close.
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close()
srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345)
dstPort := uint16(80)
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Peer FIN -> CloseWait.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0))
// Peer pushes trailing data + FIN|PSH|ACK (legal).
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100),
"FIN|PSH|ACK in CloseWait must be accepted")
// Bare ACK keepalive from peer in CloseWait must be accepted.
require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0),
"bare ACK in CloseWait must be accepted")
}

View File

@@ -17,6 +17,9 @@ const (
DefaultUDPTimeout = 30 * time.Second
// UDPCleanupInterval is how often we check for stale connections
UDPCleanupInterval = 15 * time.Second
+// EnvUDPMaxEntries caps the UDP conntrack table size.
+EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX"
)
// UDPConnTrack represents a UDP connection state
@@ -34,6 +37,7 @@ type UDPTracker struct {
cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex
+maxEntries int
flowLogger nftypes.FlowLogger
}
@@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval),
tickerCancel: cancel,
+maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries),
flowLogger: flowLogger,
}
@@ -117,14 +122,19 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
conn.UpdateCounters(direction, size)
t.mutex.Lock()
+if t.maxEntries > 0 && len(t.connections) >= t.maxEntries {
+t.evictOneLocked()
+}
t.connections[key] = conn
t.mutex.Unlock()
+if t.logger.Enabled(nblog.LevelTrace) {
if origPort != 0 {
t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort)
} else {
t.logger.Trace2("New %s UDP connection: %s", direction, key)
}
+}
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
@@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
return true
}
+// evictOneLocked removes one entry to make room. Caller must hold t.mutex.
+// Bounded sample: picks the oldest among up to evictSampleSize entries.
+func (t *UDPTracker) evictOneLocked() {
+var candKey ConnKey
+var candSeen int64
+haveCand := false
+sampled := 0
+for k, c := range t.connections {
+seen := c.lastSeen.Load()
+if !haveCand || seen < candSeen {
+candKey = k
+candSeen = seen
+haveCand = true
+}
+sampled++
+if sampled >= evictSampleSize {
+break
+}
+}
+if haveCand {
+if evicted := t.connections[candKey]; evicted != nil {
+t.sendEvent(nftypes.TypeEnd, evicted, nil)
+}
+delete(t.connections, candKey)
+}
+}
// cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
@@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() {
if conn.timeoutExceeded(t.timeout) {
delete(t.connections, key)
+if t.logger.Enabled(nblog.LevelTrace) {
t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
+}
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
}

View File

@@ -709,7 +709,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
srcIP, dstIP := m.extractIPs(d)
if !srcIP.IsValid() {
+if m.logger.Enabled(nblog.LevelError) {
m.logger.Error1("Unknown network layer: %v", d.decoded[0])
+}
return false
}
@@ -808,7 +810,9 @@ func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
return false
}
+if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
+}
return true
}
@@ -931,8 +935,10 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
// TODO: pass fragments of routed packets to forwarder
if fragment {
+if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v",
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
+}
return false
}
@@ -974,8 +980,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
+if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
+}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1025,8 +1033,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
// Drop if routing is disabled
if !m.routingEnabled.Load() {
+if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP)
+}
return true
}
@@ -1043,8 +1053,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
if !pass {
proto := getProtocolFromPacket(d)
+if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, proto, srcIP, srcPort, dstIP, dstPort)
+}
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
@@ -1126,7 +1138,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
// It returns true, true if the packet is a fragment and valid.
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
+if m.logger.Enabled(nblog.LevelTrace) {
m.logger.Trace1("couldn't decode packet, err: %s", err)
+}
return false, false
}

View File

@@ -13,6 +13,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
@@ -92,8 +93,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by
return nil, fmt.Errorf("write ICMP packet: %w", err)
}
+if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
epID(id), icmpType, icmpCode)
+}
return conn, nil
}
@@ -116,8 +119,10 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp
txBytes := f.handleEchoResponse(conn, id)
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
+if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
epID(id), icmpType, icmpCode, rtt)
+}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}
@@ -198,13 +203,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi
}
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
+if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
epID(id), icmpType, icmpCode)
+}
txBytes := f.synthesizeEchoReply(id, icmpData)
+if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
epID(id), icmpType, icmpCode, rtt)
+}
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
}

View File

@@ -1,12 +1,9 @@
package forwarder
import (
-"context"
"fmt"
-"io"
"net"
"net/netip"
-"sync"
"github.com/google/uuid"
@@ -16,7 +13,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/util/netrelay"
) )
// handleTCP is called by the TCP forwarder for new connections. // handleTCP is called by the TCP forwarder for new connections.
@@ -38,7 +37,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil {
r.Complete(true)
+if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err)
+}
return
}
@@ -61,64 +62,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep)
success = true
+if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace1("forwarder: established TCP connection %v", epID(id))
+}
go f.proxyTCP(id, inConn, outConn, ep, flowID)
}
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
+// netrelay.Relay copies bidirectionally with proper half-close propagation
+// and fully closes both conns before returning.
+bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{
+Logger: f.logger,
+})
-ctx, cancel := context.WithCancel(f.ctx)
-defer cancel()
-go func() {
-<-ctx.Done()
-// Close connections and endpoint.
-if err := inConn.Close(); err != nil && !isClosedError(err) {
-f.logger.Debug1("forwarder: inConn close error: %v", err)
-}
-if err := outConn.Close(); err != nil && !isClosedError(err) {
-f.logger.Debug1("forwarder: outConn close error: %v", err)
-}
+// Close the netstack endpoint after both conns are drained.
ep.Close()
-}()
-var wg sync.WaitGroup
-wg.Add(2)
-var (
-bytesFromInToOut int64 // bytes from client to server (tx for client)
-bytesFromOutToIn int64 // bytes from server to client (rx for client)
-errInToOut error
-errOutToIn error
-)
-go func() {
-bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
-cancel()
-wg.Done()
-}()
-go func() {
-bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
-cancel()
-wg.Done()
-}()
-wg.Wait()
-if errInToOut != nil {
-if !isClosedError(errInToOut) {
-f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
-}
-}
-if errOutToIn != nil {
-if !isClosedError(errOutToIn) {
-f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
-}
-}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
@@ -127,7 +86,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
txPackets = tcpStats.SegmentsReceived.Value() txPackets = tcpStats.SegmentsReceived.Value()
} }
if f.logger.Enabled(nblog.LevelTrace) {
f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
}
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
} }

View File

@@ -125,10 +125,12 @@ func (f *udpForwarder) cleanup() {
 			delete(f.conns, idle.id)
 			f.Unlock()
+			if f.logger.Enabled(nblog.LevelTrace) {
 				f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
+			}
 		}
 	}
 }

 // handleUDP is called by the UDP forwarder for new packets
@@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
 	_, exists := f.udpForwarder.conns[id]
 	f.udpForwarder.RUnlock()
 	if exists {
+		if f.logger.Enabled(nblog.LevelTrace) {
 			f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
+		}
 		return true
 	}
@@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
 	f.udpForwarder.Unlock()

 	success = true
+	if f.logger.Enabled(nblog.LevelTrace) {
 		f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
+	}

 	go f.proxyUDP(connCtx, pConn, id, ep)
 	return true
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
 		txPackets = udpStats.PacketsReceived.Value()
 	}

+	if f.logger.Enabled(nblog.LevelTrace) {
 		f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
+	}
 	f.udpForwarder.Lock()
 	delete(f.udpForwarder.conns, id)

View File

@@ -54,6 +54,7 @@ var levelStrings = map[Level]string{
 type logMessage struct {
 	level    Level
+	argCount uint8
 	format   string
 	arg1     any
 	arg2     any
@@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) {
 	log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
 }

+// Enabled reports whether the given level is currently logged. Callers on the
+// hot path should guard log sites with this to avoid boxing arguments into
+// any when the level is off.
+func (l *Logger) Enabled(level Level) bool {
+	return l.level.Load() >= uint32(level)
+}
+
 func (l *Logger) Error(format string) {
 	if l.level.Load() >= uint32(LevelError) {
 		select {
@@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) {
 func (l *Logger) Error1(format string, arg1 any) {
 	if l.level.Load() >= uint32(LevelError) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}:
+		case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}:
 		default:
 		}
 	}
@@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) {
 func (l *Logger) Error2(format string, arg1, arg2 any) {
 	if l.level.Load() >= uint32(LevelError) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}:
+		case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
+		default:
+		}
+	}
+}
+
+func (l *Logger) Warn2(format string, arg1, arg2 any) {
+	if l.level.Load() >= uint32(LevelWarn) {
+		select {
+		case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
 		default:
 		}
 	}
@@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) {
 func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
 	if l.level.Load() >= uint32(LevelWarn) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
+		case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
 		default:
 		}
 	}
@@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
 func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
 	if l.level.Load() >= uint32(LevelWarn) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
+		case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
 		default:
 		}
 	}
@@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
 func (l *Logger) Debug1(format string, arg1 any) {
 	if l.level.Load() >= uint32(LevelDebug) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}:
+		case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}:
 		default:
 		}
 	}
@@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) {
 func (l *Logger) Debug2(format string, arg1, arg2 any) {
 	if l.level.Load() >= uint32(LevelDebug) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}:
+		case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
 		default:
 		}
 	}
@@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) {
 func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) {
 	if l.level.Load() >= uint32(LevelDebug) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
+		case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
 		default:
 		}
 	}
 }

+// Debugf is the variadic shape. It dispatches to Debug/Debug1/Debug2/Debug3
+// to avoid allocating an args slice on the fast path when the arg count is
+// known (0-3). Args beyond 3 land on the general variadic path; callers on
+// the hot path should prefer DebugN for known counts.
+func (l *Logger) Debugf(format string, args ...any) {
+	if l.level.Load() < uint32(LevelDebug) {
+		return
+	}
+	switch len(args) {
+	case 0:
+		l.Debug(format)
+	case 1:
+		l.Debug1(format, args[0])
+	case 2:
+		l.Debug2(format, args[0], args[1])
+	case 3:
+		l.Debug3(format, args[0], args[1], args[2])
+	default:
+		l.sendVariadic(LevelDebug, format, args)
+	}
+}
+
+// sendVariadic packs a slice of arguments into a logMessage and enqueues it
+// without blocking. It is used for arg counts beyond the fixed-arity fast
+// paths. Args beyond the 8-arg slot limit are dropped so callers can't
+// produce silently empty log lines via uint8 wraparound in argCount.
+func (l *Logger) sendVariadic(level Level, format string, args []any) {
+	const maxArgs = 8
+	n := len(args)
+	if n > maxArgs {
+		n = maxArgs
+	}
+	msg := logMessage{level: level, argCount: uint8(n), format: format}
+	slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8}
+	for i := 0; i < n; i++ {
+		*slots[i] = args[i]
+	}
+	select {
+	case l.msgChannel <- msg:
+	default:
+	}
+}
+
 func (l *Logger) Trace1(format string, arg1 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}:
 		default:
 		}
 	}
@@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) {
 func (l *Logger) Trace2(format string, arg1, arg2 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}:
 		default:
 		}
 	}
@@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) {
 func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}:
 		default:
 		}
 	}
@@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) {
 func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
 		default:
 		}
 	}
@@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) {
 func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}:
 		default:
 		}
 	}
@@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) {
 func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}:
 		default:
 		}
 	}
@@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) {
 func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) {
 	if l.level.Load() >= uint32(LevelTrace) {
 		select {
-		case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
+		case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}:
 		default:
 		}
 	}
@@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) {
 	*buf = append(*buf, levelStrings[msg.level]...)
 	*buf = append(*buf, ' ')

-	// Count non-nil arguments for switch
-	argCount := 0
-	if msg.arg1 != nil {
-		argCount++
-		if msg.arg2 != nil {
-			argCount++
-			if msg.arg3 != nil {
-				argCount++
-				if msg.arg4 != nil {
-					argCount++
-					if msg.arg5 != nil {
-						argCount++
-						if msg.arg6 != nil {
-							argCount++
-							if msg.arg7 != nil {
-								argCount++
-								if msg.arg8 != nil {
-									argCount++
-								}
-							}
-						}
-					}
-				}
-			}
-		}
-	}
-
 	var formatted string
-	switch argCount {
+	switch msg.argCount {
 	case 0:
 		formatted = msg.format
 	case 1:
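The callers in this PR (the forwarder hunks earlier and the DNAT hunks that follow) all adopt the same pattern these two changes enable: check Enabled once, then call the fixed-arity variant, so no argument is boxed into any unless the level is active. A minimal sketch of a call site, reusing only the API added above (the surrounding variables are hypothetical):

	// Hot path: guard first, then the fixed-arity TraceN call.
	if logger.Enabled(nblog.LevelTrace) {
		logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
	}
	// Cooler path with unknown arity: Debugf dispatches to Debug/Debug1-3
	// for up to three args and to sendVariadic beyond that.
	logger.Debugf("sync done: %v peers, %v routes, %v rules, %v errors", p, r, ru, e)

The four-argument Debugf call lands on sendVariadic; the guarded Trace2 call costs only an atomic level load when trace logging is off.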

View File

@@ -11,6 +11,7 @@ import (
 	"github.com/google/gopacket/layers"

 	firewall "github.com/netbirdio/netbird/client/firewall/manager"
+	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

 var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
@@ -242,11 +243,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
 	}

 	if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil {
+		if m.logger.Enabled(nblog.LevelError) {
 			m.logger.Error1("failed to rewrite packet destination: %v", err)
+		}
 		return false
 	}

+	if m.logger.Enabled(nblog.LevelTrace) {
 		m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP)
+	}
 	return true
 }
@@ -264,11 +269,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
 	}

 	if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil {
+		if m.logger.Enabled(nblog.LevelError) {
 			m.logger.Error1("failed to rewrite packet source: %v", err)
+		}
 		return false
 	}

+	if m.logger.Enabled(nblog.LevelTrace) {
 		m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP)
+	}
 	return true
 }
@@ -521,7 +530,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti
 	}

 	if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil {
+		if m.logger.Enabled(nblog.LevelError) {
 			m.logger.Error1("failed to rewrite port: %v", err)
+		}
 		return false
 	}

 	d.dnatOrigPort = rule.origPort

View File

@@ -25,6 +25,7 @@ import (
 	nbssh "github.com/netbirdio/netbird/client/ssh"
 	"github.com/netbirdio/netbird/client/ssh/detection"
 	"github.com/netbirdio/netbird/util"
+	"github.com/netbirdio/netbird/util/netrelay"
 )

 const (
@@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
 				continue
 			}

-			go c.handleLocalForward(localConn, remoteAddr)
+			go c.handleLocalForward(ctx, localConn, remoteAddr)
 		}
 	}()
@@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str
 }

 // handleLocalForward handles a single local port forwarding connection
-func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
+func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) {
 	defer func() {
 		if err := localConn.Close(); err != nil {
 			log.Debugf("local port forwarding: close local connection: %v", err)
@@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
 		}
 	}()

-	nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
+	netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
 }

 // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
@@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri
 		select {
 		case <-ctx.Done():
 			return
-		case newChan := <-channelRequests:
+		case newChan, ok := <-channelRequests:
+			if !ok {
+				return
+			}
 			if newChan != nil {
-				go c.handleRemoteForwardChannel(newChan, localAddr)
+				go c.handleRemoteForwardChannel(ctx, newChan, localAddr)
 			}
 		}
 	}
 }

 // handleRemoteForwardChannel handles a single forwarded-tcpip channel
-func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
+func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) {
 	channel, reqs, err := newChan.Accept()
 	if err != nil {
 		return
@@ -675,8 +679,14 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
 	go ssh.DiscardRequests(reqs)

-	localConn, err := net.Dial("tcp", localAddr)
+	// Bound the dial so a black-holed localAddr can't pin the accepted SSH
+	// channel open indefinitely; the relay itself runs under the outer ctx.
+	dialCtx, cancelDial := context.WithTimeout(ctx, 10*time.Second)
+	var dialer net.Dialer
+	localConn, err := dialer.DialContext(dialCtx, "tcp", localAddr)
+	cancelDial()
 	if err != nil {
+		log.Debugf("remote port forwarding: dial %s: %v", localAddr, err)
 		return
 	}
 	defer func() {
@@ -685,7 +695,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st
 		}
 	}()

-	nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
+	netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
 }

 // tcpipForwardMsg represents the structure for tcpip-forward requests

View File

@@ -194,63 +194,3 @@ func buildAddressList(hostname string, remote net.Addr) []string {
 	return addresses
 }
-
-// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections.
-// It waits for both directions to complete before returning.
-// The caller is responsible for closing the connections.
-func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) {
-	done := make(chan struct{}, 2)
-
-	go func() {
-		if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) {
-			logger.Debugf("copy error (1->2): %v", err)
-		}
-		done <- struct{}{}
-	}()
-
-	go func() {
-		if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) {
-			logger.Debugf("copy error (2->1): %v", err)
-		}
-		done <- struct{}{}
-	}()
-
-	<-done
-	<-done
-}
-
-func isExpectedCopyError(err error) bool {
-	return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled)
-}
-
-// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections.
-// It waits for both directions to complete or for context cancellation before returning.
-// Both connections are closed when the function returns.
-func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) {
-	done := make(chan struct{}, 2)
-
-	go func() {
-		if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) {
-			logger.Debugf("copy error (1->2): %v", err)
-		}
-		done <- struct{}{}
-	}()
-
-	go func() {
-		if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) {
-			logger.Debugf("copy error (2->1): %v", err)
-		}
-		done <- struct{}{}
-	}()
-
-	select {
-	case <-ctx.Done():
-	case <-done:
-		select {
-		case <-ctx.Done():
-		case <-done:
-		}
-	}
-
-	_ = conn1.Close()
-	_ = conn2.Close()
-}

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util/netrelay"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne
} }
go cryptossh.DiscardRequests(clientReqs) go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
} }
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh
} }
go cryptossh.DiscardRequests(clientReqs) go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())})
} }
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {

View File

@@ -17,7 +17,7 @@ import (
 	log "github.com/sirupsen/logrus"
 	cryptossh "golang.org/x/crypto/ssh"

-	nbssh "github.com/netbirdio/netbird/client/ssh"
+	"github.com/netbirdio/netbird/util/netrelay"
 )

 const privilegedPortThreshold = 1024
@@ -356,7 +356,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
 		return
 	}

-	nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel)
+	netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger})
 }

 // openForwardChannel creates an SSH forwarded-tcpip channel

View File

@@ -10,6 +10,7 @@ import (
 	"net"
 	"net/netip"
 	"slices"
+	"strconv"
 	"strings"
 	"sync"
 	"time"
@@ -26,6 +27,7 @@ import (
 	"github.com/netbirdio/netbird/client/ssh/detection"
 	"github.com/netbirdio/netbird/shared/auth"
 	"github.com/netbirdio/netbird/shared/auth/jwt"
+	"github.com/netbirdio/netbird/util/netrelay"
 	"github.com/netbirdio/netbird/version"
 )
@@ -52,6 +54,10 @@ const (
 	DefaultJWTMaxTokenAge = 10 * 60
 )

+// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to
+// the forwarded destination before rejecting the SSH channel.
+const directTCPIPDialTimeout = 30 * time.Second
+
 var (
 	ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
 	ErrUserNotFound           = errors.New("user not found")
@@ -891,5 +897,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn,
 	s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr)
 	logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)

-	ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
+	s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger)
+}
+
+// relayDirectTCPIP is a netrelay-based replacement for gliderlabs'
+// DirectTCPIPHandler. The upstream handler closes both sides on the first
+// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its
+// own terms.
+func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) {
+	dest := net.JoinHostPort(host, strconv.Itoa(port))
+	dialer := net.Dialer{Timeout: directTCPIPDialTimeout}
+	dconn, err := dialer.DialContext(ctx, "tcp", dest)
+	if err != nil {
+		_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
+		return
+	}
+	ch, reqs, err := newChan.Accept()
+	if err != nil {
+		_ = dconn.Close()
+		return
+	}
+	go cryptossh.DiscardRequests(reqs)
+	netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger})
 }

View File

@@ -25,6 +25,12 @@ func (c *peekedConn) Read(b []byte) (int, error) {
 	return c.reader.Read(b)
 }

+// halfCloser matches connections that support shutting down the write
+// side while keeping the read side open (e.g. *net.TCPConn).
+type halfCloser interface {
+	CloseWrite() error
+}
+
 // CloseWrite delegates to the underlying connection if it supports
 // half-close (e.g. *net.TCPConn). Without this, embedding net.Conn
 // as an interface hides the concrete type's CloseWrite method, making
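This is the same shape netrelay keys on for half-close propagation. The method body itself is truncated in this diff; a sketch of what such delegation typically looks like is below, where the embedded Conn field and the fallback error text are assumptions, not code from this change:

	// CloseWrite half-closes the write side when the underlying conn
	// supports it; otherwise it reports that half-close is unavailable.
	func (c *peekedConn) CloseWrite() error {
		if hc, ok := c.Conn.(halfCloser); ok {
			return hc.CloseWrite()
		}
		return fmt.Errorf("underlying connection does not support CloseWrite")
	}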

View File

@@ -1,156 +0,0 @@
package tcp
import (
"context"
"errors"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/internal/netutil"
)
// errIdleTimeout is returned when a relay connection is closed due to inactivity.
var errIdleTimeout = errors.New("idle timeout")
// DefaultIdleTimeout is the default idle timeout for TCP relay connections.
// A zero value disables idle timeout checking.
const DefaultIdleTimeout = 5 * time.Minute
// halfCloser is implemented by connections that support half-close
// (e.g. *net.TCPConn). When one copy direction finishes, we signal
// EOF to the remote by closing the write side while keeping the read
// side open so the other direction can drain.
type halfCloser interface {
CloseWrite() error
}
// copyBufPool avoids allocating a new 32KB buffer per io.Copy call.
var copyBufPool = sync.Pool{
New: func() any {
buf := make([]byte, 32*1024)
return &buf
},
}
// Relay copies data bidirectionally between src and dst until both
// sides are done or the context is canceled. When idleTimeout is
// non-zero, each direction's read is deadline-guarded; if no data
// flows within the timeout the connection is torn down. When one
// direction finishes, it half-closes the write side of the
// destination (if supported) to signal EOF, allowing the other
// direction to drain gracefully before the full connection teardown.
func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
_ = src.Close()
_ = dst.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var errSrcToDst, errDstToSrc error
go func() {
defer wg.Done()
srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout)
halfClose(dst)
cancel()
}()
go func() {
defer wg.Done()
dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout)
halfClose(src)
cancel()
}()
wg.Wait()
if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) {
logger.Debug("relay closed due to idle timeout")
}
if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) {
logger.Debugf("relay copy error (src→dst): %v", errSrcToDst)
}
if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) {
logger.Debugf("relay copy error (dst→src): %v", errDstToSrc)
}
return srcToDst, dstToSrc
}
// copyWithIdleTimeout copies from src to dst using a pooled buffer.
// When idleTimeout > 0 it sets a read deadline on src before each
// read and treats a timeout as an idle-triggered close.
func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) {
bufp := copyBufPool.Get().(*[]byte)
defer copyBufPool.Put(bufp)
if idleTimeout <= 0 {
return io.CopyBuffer(dst, src, *bufp)
}
conn, ok := src.(net.Conn)
if !ok {
return io.CopyBuffer(dst, src, *bufp)
}
buf := *bufp
var total int64
for {
if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
return total, err
}
nr, readErr := src.Read(buf)
if nr > 0 {
n, err := checkedWrite(dst, buf[:nr])
total += n
if err != nil {
return total, err
}
}
if readErr != nil {
if netutil.IsTimeout(readErr) {
return total, errIdleTimeout
}
return total, readErr
}
}
}
// checkedWrite writes buf to dst and returns the number of bytes written.
// It guards against short writes and negative counts per io.Copy convention.
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
nw, err := dst.Write(buf)
if nw < 0 || nw > len(buf) {
nw = 0
}
if err != nil {
return int64(nw), err
}
if nw != len(buf) {
return int64(nw), io.ErrShortWrite
}
return int64(nw), nil
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err)
}
// halfClose attempts to half-close the write side of the connection.
// If the connection does not support half-close, this is a no-op.
func halfClose(conn net.Conn) {
if hc, ok := conn.(halfCloser); ok {
// Best-effort; the full close will follow shortly.
_ = hc.CloseWrite()
}
}

View File

@@ -13,8 +13,13 @@ import (
 	"github.com/stretchr/testify/require"

 	"github.com/netbirdio/netbird/proxy/internal/netutil"
+	"github.com/netbirdio/netbird/util/netrelay"
 )

+func testRelay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (int64, int64) {
+	return netrelay.Relay(ctx, src, dst, netrelay.Options{IdleTimeout: idleTimeout, Logger: logger})
+}
+
 func TestRelay_BidirectionalCopy(t *testing.T) {
 	srcClient, srcServer := net.Pipe()
 	dstClient, dstServer := net.Pipe()
@@ -41,7 +46,7 @@ func TestRelay_BidirectionalCopy(t *testing.T) {
 		srcClient.Close()
 	}()

-	s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0)
+	s2d, d2s := testRelay(ctx, logger, srcServer, dstServer, 0)

 	assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst")
 	assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src")
@@ -58,7 +63,7 @@ func TestRelay_ContextCancellation(t *testing.T) {
 	done := make(chan struct{})
 	go func() {
-		Relay(ctx, logger, srcServer, dstServer, 0)
+		testRelay(ctx, logger, srcServer, dstServer, 0)
 		close(done)
 	}()
@@ -85,7 +90,7 @@ func TestRelay_OneSideClosed(t *testing.T) {
 	done := make(chan struct{})
 	go func() {
-		Relay(ctx, logger, srcServer, dstServer, 0)
+		testRelay(ctx, logger, srcServer, dstServer, 0)
 		close(done)
 	}()
@@ -129,7 +134,7 @@ func TestRelay_LargeTransfer(t *testing.T) {
 		dstClient.Close()
 	}()

-	s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0)
+	s2d, _ := testRelay(ctx, logger, srcServer, dstServer, 0)
 	assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes")
 	require.NoError(t, <-errCh)
 }
@@ -182,7 +187,7 @@ func TestRelay_IdleTimeout(t *testing.T) {
 	done := make(chan struct{})
 	var s2d, d2s int64
 	go func() {
-		s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
+		s2d, d2s = testRelay(ctx, logger, srcServer, dstServer, 200*time.Millisecond)
 		close(done)
 	}()

View File

@@ -16,6 +16,7 @@ import (
 	"github.com/netbirdio/netbird/proxy/internal/accesslog"
 	"github.com/netbirdio/netbird/proxy/internal/restrict"
 	"github.com/netbirdio/netbird/proxy/internal/types"
+	"github.com/netbirdio/netbird/util/netrelay"
 )

 // defaultDialTimeout is the fallback dial timeout when no per-route
@@ -528,11 +529,14 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route
 	idleTimeout := route.SessionIdleTimeout
 	if idleTimeout <= 0 {
-		idleTimeout = DefaultIdleTimeout
+		idleTimeout = netrelay.DefaultIdleTimeout
 	}

 	start := time.Now()
-	s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout)
+	s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{
+		IdleTimeout: idleTimeout,
+		Logger:      entry,
+	})
 	elapsed := time.Since(start)

 	if obs != nil {

util/netrelay/relay.go (new file, 238 additions)
View File

@@ -0,0 +1,238 @@
// Package netrelay provides a bidirectional byte-copy helper for TCP-like
// connections with correct half-close propagation.
//
// When one direction reads EOF, the write side of the opposite connection is
// half-closed (CloseWrite) so the peer sees FIN, then the second direction is
// allowed to drain to its own EOF before both connections are fully closed.
// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the
// naive "cancel-both-on-first-EOF" pattern breaks.
package netrelay
import (
"context"
"errors"
"io"
"net"
"sync"
"sync/atomic"
"syscall"
"time"
)
// DebugLogger is the minimal interface netrelay uses to surface teardown
// errors. Both *logrus.Entry and *nblog.Logger (via its Debugf method)
// satisfy it, so callers can pass whichever they already use without an
// adapter. Debugf is the only required method; richer loggers need
// only satisfy this single shape.
type DebugLogger interface {
Debugf(format string, args ...any)
}
// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers
// that want an idle timeout but have no specific preference can use this.
const DefaultIdleTimeout = 5 * time.Minute
// halfCloser is implemented by connections that support half-close
// (e.g. *net.TCPConn, *gonet.TCPConn).
type halfCloser interface {
CloseWrite() error
}
var copyBufPool = sync.Pool{
New: func() any {
buf := make([]byte, 32*1024)
return &buf
},
}
// Options configures Relay behavior. The zero value is valid: no idle timeout,
// no logging.
type Options struct {
// IdleTimeout tears down the session if no bytes flow in either
// direction within this window. It is a connection-wide watchdog, so a
// long unidirectional transfer on one side keeps the other side alive.
// Zero disables idle tracking.
IdleTimeout time.Duration
// Logger receives debug-level copy/idle errors. Nil suppresses logging.
// Any logger with a Debugf method is accepted (logrus.Entry,
// uspfilter's nblog.Logger, etc.).
Logger DebugLogger
}
// Relay copies bytes in both directions between a and b until both directions
// EOF or ctx is canceled. On each direction's EOF it half-closes the
// opposite conn's write side (best effort) so the peer sees FIN while the
// other direction drains. Both conns are fully closed when Relay returns.
//
// a and b only need to implement io.ReadWriteCloser; connections that also
// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close
// propagation. Options.IdleTimeout, when set, is enforced by a connection-wide
// watchdog that tracks reads in either direction.
//
// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read →
// a.Write). Errors are logged via Options.Logger when set; they are not
// returned because a relay always terminates on some kind of EOF/cancel.
func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) {
ctx, cancel := context.WithCancel(ctx)
closeDone := make(chan struct{})
defer func() {
cancel()
<-closeDone
}()
go func() {
<-ctx.Done()
_ = a.Close()
_ = b.Close()
close(closeDone)
}()
// Both sides must support CloseWrite to propagate half-close. If either
// doesn't, a direction's EOF can't be signaled to the peer and the other
// direction would block forever waiting for data; in that case we fall
// back to the cancel-both-on-first-EOF behavior.
_, aHC := a.(halfCloser)
_, bHC := b.(halfCloser)
halfCloseSupported := aHC && bHC
var (
lastActivity atomic.Int64
idleHit atomic.Bool
)
lastActivity.Store(time.Now().UnixNano())
if opts.IdleTimeout > 0 {
go watchdog(ctx, cancel, &lastActivity, &idleHit, opts.IdleTimeout)
}
var wg sync.WaitGroup
wg.Add(2)
var errAToB, errBToA error
go func() {
defer wg.Done()
aToB, errAToB = copyTracked(b, a, &lastActivity)
if halfCloseSupported && isCleanEOF(errAToB) {
halfClose(b)
} else {
cancel()
}
}()
go func() {
defer wg.Done()
bToA, errBToA = copyTracked(a, b, &lastActivity)
if halfCloseSupported && isCleanEOF(errBToA) {
halfClose(a)
} else {
cancel()
}
}()
wg.Wait()
if opts.Logger != nil {
if idleHit.Load() {
opts.Logger.Debugf("relay closed due to idle timeout")
}
if errAToB != nil && !isExpectedCopyError(errAToB) {
opts.Logger.Debugf("relay copy error (a→b): %v", errAToB)
}
if errBToA != nil && !isExpectedCopyError(errBToA) {
opts.Logger.Debugf("relay copy error (b→a): %v", errBToA)
}
}
return aToB, bToA
}
// watchdog enforces a connection-wide idle timeout. It cancels ctx when no
// activity has been seen on either direction for idle. It exits as soon as
// ctx is canceled so it doesn't outlive the relay.
func watchdog(ctx context.Context, cancel context.CancelFunc, lastActivity *atomic.Int64, idleHit *atomic.Bool, idle time.Duration) {
// Cap the tick at 50ms so detection latency stays bounded regardless of
// how large idle is, and fall back to idle/2 when that is smaller so
// very short timeouts (mainly in tests) are still caught promptly.
tick := min(idle/2, 50*time.Millisecond)
if tick <= 0 {
tick = time.Millisecond
}
t := time.NewTicker(tick)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
last := time.Unix(0, lastActivity.Load())
if time.Since(last) >= idle {
idleHit.Store(true)
cancel()
return
}
}
}
}
// copyTracked copies from src to dst using a pooled buffer, updating
// lastActivity on every successful read so a shared watchdog can enforce a
// connection-wide idle timeout.
func copyTracked(dst io.Writer, src io.Reader, lastActivity *atomic.Int64) (int64, error) {
bufp := copyBufPool.Get().(*[]byte)
defer copyBufPool.Put(bufp)
buf := *bufp
var total int64
for {
nr, readErr := src.Read(buf)
if nr > 0 {
lastActivity.Store(time.Now().UnixNano())
n, werr := checkedWrite(dst, buf[:nr])
total += n
if werr != nil {
return total, werr
}
}
if readErr != nil {
return total, readErr
}
}
}
func checkedWrite(dst io.Writer, buf []byte) (int64, error) {
nw, err := dst.Write(buf)
if nw < 0 || nw > len(buf) {
nw = 0
}
if err != nil {
return int64(nw), err
}
if nw != len(buf) {
return int64(nw), io.ErrShortWrite
}
return int64(nw), nil
}
func halfClose(conn io.ReadWriteCloser) {
if hc, ok := conn.(halfCloser); ok {
_ = hc.CloseWrite()
}
}
// isCleanEOF reports whether a copy terminated on a graceful end-of-stream.
// Only in that case is it correct to propagate the EOF via CloseWrite on the
// peer; any other error means the flow is broken and both directions should
// tear down.
func isCleanEOF(err error) bool {
return err == nil || errors.Is(err, io.EOF)
}
func isExpectedCopyError(err error) bool {
return errors.Is(err, net.ErrClosed) ||
errors.Is(err, context.Canceled) ||
errors.Is(err, io.EOF) ||
errors.Is(err, syscall.ECONNRESET) ||
errors.Is(err, syscall.EPIPE) ||
errors.Is(err, syscall.ECONNABORTED)
}
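
For integrators outside this PR, wiring Relay into a plain accept loop looks like the sketch below. It is a minimal example, not code from this change: serveProxy, ln, and backendAddr are assumed names; only the netrelay API is the one defined above.

	// serveProxy relays each accepted conn to a fixed backend address.
	func serveProxy(ctx context.Context, ln net.Listener, backendAddr string, logger netrelay.DebugLogger) {
		for {
			conn, err := ln.Accept()
			if err != nil {
				return // listener closed
			}
			go func() {
				backend, err := (&net.Dialer{}).DialContext(ctx, "tcp", backendAddr)
				if err != nil {
					_ = conn.Close()
					return
				}
				// Relay fully closes both conns before returning, so no
				// further cleanup is needed here.
				netrelay.Relay(ctx, conn, backend, netrelay.Options{
					IdleTimeout: netrelay.DefaultIdleTimeout,
					Logger:      logger,
				})
			}()
		}
	}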

util/netrelay/relay_test.go (new file, 221 additions)
View File

@@ -0,0 +1,221 @@
package netrelay
import (
"io"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// tcpPair returns two connected loopback TCP conns.
func tcpPair(t *testing.T) (*net.TCPConn, *net.TCPConn) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer ln.Close()
type result struct {
c *net.TCPConn
err error
}
ch := make(chan result, 1)
go func() {
c, err := ln.Accept()
if err != nil {
ch <- result{nil, err}
return
}
ch <- result{c.(*net.TCPConn), nil}
}()
dial, err := net.Dial("tcp", ln.Addr().String())
require.NoError(t, err)
r := <-ch
require.NoError(t, r.err)
return dial.(*net.TCPConn), r.c
}
// TestRelayHalfClose exercises the shutdown(SHUT_WR) scenario that the naive
// cancel-both-on-first-EOF pattern breaks. Client A shuts down its write
// side; B must still be able to write a full response and A must receive
// all of it before its read returns EOF.
func TestRelayHalfClose(t *testing.T) {
// Real peer pairs for each side of the relay. We relay between relayA
// and relayB. Peer A talks through relayA; peer B talks through relayB.
peerA, relayA := tcpPair(t)
relayB, peerB := tcpPair(t)
defer peerA.Close()
defer peerB.Close()
// Bound blocking reads/writes so a broken relay fails the test instead of
// hanging the test process.
deadline := time.Now().Add(5 * time.Second)
require.NoError(t, peerA.SetDeadline(deadline))
require.NoError(t, peerB.SetDeadline(deadline))
ctx := t.Context()
done := make(chan struct{})
go func() {
Relay(ctx, relayA, relayB, Options{})
close(done)
}()
// Peer A sends a request, then half-closes its write side.
req := []byte("request-payload")
_, err := peerA.Write(req)
require.NoError(t, err)
require.NoError(t, peerA.CloseWrite())
// Peer B reads the request to EOF (FIN must have propagated).
got, err := io.ReadAll(peerB)
require.NoError(t, err)
require.Equal(t, req, got)
// Peer B writes its response; peer A must receive all of it even though
// peer A's write side is already closed.
resp := make([]byte, 64*1024)
for i := range resp {
resp[i] = byte(i)
}
_, err = peerB.Write(resp)
require.NoError(t, err)
require.NoError(t, peerB.Close())
gotResp, err := io.ReadAll(peerA)
require.NoError(t, err)
require.Equal(t, resp, gotResp)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("relay did not return")
}
}
// TestRelayFullDuplex verifies bidirectional copy in the simple case.
func TestRelayFullDuplex(t *testing.T) {
peerA, relayA := tcpPair(t)
relayB, peerB := tcpPair(t)
defer peerA.Close()
defer peerB.Close()
// Bound blocking reads/writes so a broken relay fails the test instead of
// hanging the test process.
deadline := time.Now().Add(5 * time.Second)
require.NoError(t, peerA.SetDeadline(deadline))
require.NoError(t, peerB.SetDeadline(deadline))
ctx := t.Context()
done := make(chan struct{})
go func() {
Relay(ctx, relayA, relayB, Options{})
close(done)
}()
type result struct {
got []byte
err error
}
resA := make(chan result, 1)
resB := make(chan result, 1)
msgAB := []byte("hello-from-a")
msgBA := []byte("hello-from-b")
go func() {
if _, err := peerA.Write(msgAB); err != nil {
resA <- result{err: err}
return
}
buf := make([]byte, len(msgBA))
_, err := io.ReadFull(peerA, buf)
resA <- result{got: buf, err: err}
_ = peerA.Close()
}()
go func() {
if _, err := peerB.Write(msgBA); err != nil {
resB <- result{err: err}
return
}
buf := make([]byte, len(msgAB))
_, err := io.ReadFull(peerB, buf)
resB <- result{got: buf, err: err}
_ = peerB.Close()
}()
a, b := <-resA, <-resB
require.NoError(t, a.err)
require.Equal(t, msgBA, a.got)
require.NoError(t, b.err)
require.Equal(t, msgAB, b.got)
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("relay did not return")
}
}
// TestRelayNoHalfCloseFallback ensures Relay terminates when the underlying
// conns don't support CloseWrite (e.g. net.Pipe). Without the fallback to
// cancel-both-on-first-EOF, the second direction would block forever.
func TestRelayNoHalfCloseFallback(t *testing.T) {
a1, a2 := net.Pipe()
b1, b2 := net.Pipe()
defer a1.Close()
defer b1.Close()
ctx := t.Context()
done := make(chan struct{})
go func() {
Relay(ctx, a2, b2, Options{})
close(done)
}()
// Close peer A's side; a2's Read will return EOF.
require.NoError(t, a1.Close())
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("relay did not terminate when half-close is unsupported")
}
}
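
One number worth keeping in mind when reading the idle test below: with tick = min(idle/2, 50ms), the watchdog detects a silent flow no later than roughly idle + tick after the last byte, since the last-activity check can miss by at most one tick. Working that through for a few values, under that assumption:

	idle = 150ms -> tick = min(75ms, 50ms) = 50ms, teardown by ~200ms
	idle = 20ms  -> tick = min(10ms, 50ms) = 10ms, teardown by ~30ms
	idle = 5m    -> tick = 50ms,                   teardown by ~5m plus 50ms

The bounds asserted by TestRelayIdleTimeout (elapsed >= idle and elapsed < idle + 500ms) sit comfortably inside that envelope, with slack left for goroutine scheduling.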
// TestRelayIdleTimeout ensures the idle watchdog tears down a silent flow.
func TestRelayIdleTimeout(t *testing.T) {
peerA, relayA := tcpPair(t)
relayB, peerB := tcpPair(t)
defer peerA.Close()
defer peerB.Close()
ctx := t.Context()
const idle = 150 * time.Millisecond
start := time.Now()
done := make(chan struct{})
go func() {
Relay(ctx, relayA, relayB, Options{IdleTimeout: idle})
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("relay did not close on idle")
}
elapsed := time.Since(start)
require.GreaterOrEqual(t, elapsed, idle,
"relay must not close before the idle timeout elapses")
require.Less(t, elapsed, idle+500*time.Millisecond,
"relay should close shortly after the idle timeout")
}