diff --git a/client/firewall/uspfilter/conntrack/cap_test.go b/client/firewall/uspfilter/conntrack/cap_test.go new file mode 100644 index 000000000..1f633f134 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/cap_test.go @@ -0,0 +1,92 @@ +package conntrack + +import ( + "net/netip" + "testing" + + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" +) + +func TestTCPCapEvicts(t *testing.T) { + t.Setenv(EnvTCPMaxEntries, "4") + + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 4, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + for i := 0; i < 10; i++ { + tracker.TrackOutbound(src, dst, uint16(10000+i), 80, TCPSyn, 0) + } + require.LessOrEqual(t, len(tracker.connections), 4, + "TCP table must not exceed the configured cap") + require.Greater(t, len(tracker.connections), 0, + "some entries must remain after eviction") +} + +func TestTCPCapPrefersTombstonedForEviction(t *testing.T) { + t.Setenv(EnvTCPMaxEntries, "3") + + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + // Fill to cap with 3 live connections. + for i := 0; i < 3; i++ { + tracker.TrackOutbound(src, dst, uint16(20000+i), 80, TCPSyn, 0) + } + require.Len(t, tracker.connections, 3) + + // Tombstone one by sending RST through IsValidInbound. + tombstonedKey := ConnKey{SrcIP: src, DstIP: dst, SrcPort: 20001, DstPort: 80} + require.True(t, tracker.IsValidInbound(dst, src, 80, 20001, TCPRst|TCPAck, 0)) + require.True(t, tracker.connections[tombstonedKey].IsTombstone()) + + // Another live connection forces eviction. The tombstone must go first. 
+ tracker.TrackOutbound(src, dst, uint16(29999), 80, TCPSyn, 0) + + _, tombstonedStillPresent := tracker.connections[tombstonedKey] + require.False(t, tombstonedStillPresent, + "tombstoned entry should be evicted before live entries") + require.LessOrEqual(t, len(tracker.connections), 3) +} + +func TestUDPCapEvicts(t *testing.T) { + t.Setenv(EnvUDPMaxEntries, "5") + + tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 5, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + for i := 0; i < 12; i++ { + tracker.TrackOutbound(src, dst, uint16(30000+i), 53, 0) + } + require.LessOrEqual(t, len(tracker.connections), 5) + require.Greater(t, len(tracker.connections), 0) +} + +func TestICMPCapEvicts(t *testing.T) { + t.Setenv(EnvICMPMaxEntries, "3") + + tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger) + defer tracker.Close() + require.Equal(t, 3, tracker.maxEntries) + + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + echoReq := layers.CreateICMPv4TypeCode(uint8(layers.ICMPv4TypeEchoRequest), 0) + for i := 0; i < 8; i++ { + tracker.TrackOutbound(src, dst, uint16(i), echoReq, nil, 64) + } + require.LessOrEqual(t, len(tracker.connections), 3) + require.Greater(t, len(tracker.connections), 0) +} diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index 7be0dd78f..5ddfedccf 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -3,14 +3,61 @@ package conntrack import ( "fmt" "net/netip" + "os" + "strconv" "sync/atomic" "time" "github.com/google/uuid" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) +// evictSampleSize bounds how many map entries we scan per eviction call. 
+// Keeps eviction O(1) even at cap under sustained load; the sampled-LRU +// heuristic is good enough for a conntrack table that only overflows under +// abuse. +const evictSampleSize = 8 + +// envDuration parses an os.Getenv(name) as a time.Duration. Falls back to +// def on empty or invalid; logs a warning on invalid. +func envDuration(logger *nblog.Logger, name string, def time.Duration) time.Duration { + v := os.Getenv(name) + if v == "" { + return def + } + d, err := time.ParseDuration(v) + if err != nil { + logger.Warn3("invalid %s=%q: %v, using default", name, v, err) + return def + } + if d <= 0 { + logger.Warn2("invalid %s=%q: must be positive, using default", name, v) + return def + } + return d +} + +// envInt parses an os.Getenv(name) as an int. Falls back to def on empty, +// invalid, or non-positive. Logs a warning on invalid input. +func envInt(logger *nblog.Logger, name string, def int) int { + v := os.Getenv(name) + if v == "" { + return def + } + n, err := strconv.Atoi(v) + switch { + case err != nil: + logger.Warn3("invalid %s=%q: %v, using default", name, v, err) + return def + case n <= 0: + logger.Warn2("invalid %s=%q: must be positive, using default", name, v) + return def + } + return n +} + // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { FlowId uuid.UUID diff --git a/client/firewall/uspfilter/conntrack/defaults_desktop.go b/client/firewall/uspfilter/conntrack/defaults_desktop.go new file mode 100644 index 000000000..2f07f5984 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/defaults_desktop.go @@ -0,0 +1,11 @@ +//go:build !ios && !android + +package conntrack + +// Default per-tracker entry caps on desktop/server platforms. These mirror +// typical Linux netfilter nf_conntrack_max territory with ample headroom. 
+const ( + DefaultMaxTCPEntries = 65536 + DefaultMaxUDPEntries = 16384 + DefaultMaxICMPEntries = 2048 +) diff --git a/client/firewall/uspfilter/conntrack/defaults_mobile.go b/client/firewall/uspfilter/conntrack/defaults_mobile.go new file mode 100644 index 000000000..c9e05d229 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/defaults_mobile.go @@ -0,0 +1,13 @@ +//go:build ios || android + +package conntrack + +// Default per-tracker entry caps on mobile platforms. iOS network extensions +// are capped at ~50 MB; Android runs under aggressive memory pressure. These +// values keep conntrack footprint well under 5 MB worst case (TCPConnTrack +// is ~200 B plus map overhead). +const ( + DefaultMaxTCPEntries = 4096 + DefaultMaxUDPEntries = 2048 + DefaultMaxICMPEntries = 512 +) diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 50b663642..2fd37145a 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -44,6 +44,9 @@ type ICMPConnTrack struct { ICMPCode uint8 } +// EnvICMPMaxEntries caps the ICMP conntrack table size. 
+const EnvICMPMaxEntries = "NB_CONNTRACK_ICMP_MAX" + // ICMPTracker manages ICMP connection states type ICMPTracker struct { logger *nblog.Logger @@ -52,6 +55,7 @@ type ICMPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc mutex sync.RWMutex + maxEntries int flowLogger nftypes.FlowLogger } @@ -135,6 +139,7 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nfty timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), tickerCancel: cancel, + maxEntries: envInt(logger, EnvICMPMaxEntries, DefaultMaxICMPEntries), flowLogger: flowLogger, } @@ -221,7 +226,9 @@ func (t *ICMPTracker) track( // non echo requests don't need tracking if typ != uint8(layers.ICMPv4TypeEchoRequest) { - t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + } t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size) return } @@ -240,10 +247,15 @@ func (t *ICMPTracker) track( conn.UpdateCounters(direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace3("New %s ICMP connection %s - %s", direction, key, icmpInfo) + } t.sendEvent(nftypes.TypeStart, conn, ruleId) } @@ -286,6 +298,34 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) { } } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded sample scan: picks the oldest among up to evictSampleSize entries. 
+func (t *ICMPTracker) evictOneLocked() { + var candKey ICMPConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + delete(t.connections, candKey) + } +} + func (t *ICMPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() @@ -294,8 +334,10 @@ func (t *ICMPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", - key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace5("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 335a3abab..9edc9af22 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -38,6 +38,27 @@ const ( TCPHandshakeTimeout = 60 * time.Second // TCPCleanupInterval is how often we check for stale connections TCPCleanupInterval = 5 * time.Minute + // FinWaitTimeout bounds FIN_WAIT_1 / FIN_WAIT_2 / CLOSING states. + // Matches Linux netfilter nf_conntrack_tcp_timeout_fin_wait. + FinWaitTimeout = 60 * time.Second + // CloseWaitTimeout bounds CLOSE_WAIT. Matches Linux default; apps + // holding CloseWait longer than this should bump the env var. + CloseWaitTimeout = 60 * time.Second + // LastAckTimeout bounds LAST_ACK. Matches Linux default. 
+ LastAckTimeout = 30 * time.Second +) + +// Env vars to override per-state teardown timeouts. Values parsed by +// time.ParseDuration (e.g. "120s", "2m"). Invalid values fall back to the +// defaults above with a warning. +const ( + EnvTCPFinWaitTimeout = "NB_CONNTRACK_TCP_FIN_WAIT_TIMEOUT" + EnvTCPCloseWaitTimeout = "NB_CONNTRACK_TCP_CLOSE_WAIT_TIMEOUT" + EnvTCPLastAckTimeout = "NB_CONNTRACK_TCP_LAST_ACK_TIMEOUT" + + // EnvTCPMaxEntries caps the TCP conntrack table size. Oldest entries + // (tombstones first) are evicted when the cap is reached. + EnvTCPMaxEntries = "NB_CONNTRACK_TCP_MAX" ) // TCPState represents the state of a TCP connection @@ -133,14 +154,18 @@ func (t *TCPConnTrack) SetTombstone() { // TCPTracker manages TCP connection states type TCPTracker struct { - logger *nblog.Logger - connections map[ConnKey]*TCPConnTrack - mutex sync.RWMutex - cleanupTicker *time.Ticker - tickerCancel context.CancelFunc - timeout time.Duration - waitTimeout time.Duration - flowLogger nftypes.FlowLogger + logger *nblog.Logger + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + tickerCancel context.CancelFunc + timeout time.Duration + waitTimeout time.Duration + finWaitTimeout time.Duration + closeWaitTimeout time.Duration + lastAckTimeout time.Duration + maxEntries int + flowLogger nftypes.FlowLogger } // NewTCPTracker creates a new TCP connection tracker @@ -155,13 +180,17 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp ctx, cancel := context.WithCancel(context.Background()) tracker := &TCPTracker{ - logger: logger, - connections: make(map[ConnKey]*TCPConnTrack), - cleanupTicker: time.NewTicker(TCPCleanupInterval), - tickerCancel: cancel, - timeout: timeout, - waitTimeout: waitTimeout, - flowLogger: flowLogger, + logger: logger, + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + tickerCancel: cancel, + timeout: timeout, + waitTimeout: 
waitTimeout, + finWaitTimeout: envDuration(logger, EnvTCPFinWaitTimeout, FinWaitTimeout), + closeWaitTimeout: envDuration(logger, EnvTCPCloseWaitTimeout, CloseWaitTimeout), + lastAckTimeout: envDuration(logger, EnvTCPLastAckTimeout, LastAckTimeout), + maxEntries: envInt(logger, EnvTCPMaxEntries, DefaultMaxTCPEntries), + flowLogger: flowLogger, } go tracker.cleanupRoutine(ctx) @@ -209,6 +238,12 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla if exists || flags&TCPSyn == 0 { return } + // Reject illegal SYN combinations (SYN+FIN, SYN+RST, …) so they don't + // create spurious conntrack entries. Not mandated by RFC 9293 but a + // common hardening (Linux netfilter/nftables rejects these too). + if !isValidFlagCombination(flags) { + return + } conn := &TCPConnTrack{ BaseConnTrack: BaseConnTrack{ @@ -225,20 +260,65 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.state.Store(int32(TCPStateNew)) conn.DNATOrigPort.Store(uint32(origPort)) - if origPort != 0 { - t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) - } else { - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if t.logger.Enabled(nblog.LevelTrace) { + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() t.sendEvent(nftypes.TypeStart, conn, ruleID) } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded scan: samples up to evictSampleSize pseudo-random entries (Go map +// iteration order is randomized), preferring a tombstone. 
If no tombstone +// found in the sample, evicts the oldest among the sampled entries. O(1) +// worst case — cheap enough to run on every insert at cap during abuse. +func (t *TCPTracker) evictOneLocked() { + var candKey ConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + if c.IsTombstone() { + delete(t.connections, k) + return + } + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + // TypeEnd is already emitted at the state transition to + // TimeWait and when a connection is tombstoned. Only emit + // here when we're reaping a still-active flow. + if evicted.GetState() != TCPStateTimeWait && !evicted.IsTombstone() { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + } + delete(t.connections, candKey) + } +} + // IsValidInbound checks if an inbound TCP packet matches a tracked connection func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool { key := ConnKey{ @@ -256,12 +336,19 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui return false } + // Reject illegal flag combinations regardless of state. These never belong + // to a legitimate flow and must not advance or tear down state. 
+ if !isValidFlagCombination(flags) { + if t.logger.Enabled(nblog.LevelWarn) { + t.logger.Warn3("TCP illegal flag combination %x for connection %s (state %s)", flags, key, conn.GetState()) + } + return false + } + currentState := conn.GetState() if !t.isValidStateForFlags(currentState, flags) { - t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) - // allow all flags for established for now - if currentState == TCPStateEstablished { - return true + if t.logger.Enabled(nblog.LevelWarn) { + t.logger.Warn3("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) } return false } @@ -270,116 +357,208 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui return true } -// updateState updates the TCP connection state based on flags +// updateState updates the TCP connection state based on flags. func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) { - conn.UpdateLastSeen() conn.UpdateCounters(packetDir, size) + // Malformed flag combinations must not refresh lastSeen or drive state, + // otherwise spoofed packets keep a dead flow alive past its timeout. + if !isValidFlagCombination(flags) { + return + } + + conn.UpdateLastSeen() + currentState := conn.GetState() if flags&TCPRst != 0 { - if conn.CompareAndSwapState(currentState, TCPStateClosed) { - conn.SetTombstone() - t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", - key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) - } + // Hardening beyond RFC 9293 §3.10.7.4: without sequence tracking we + // cannot apply the RFC 5961 in-window RST check, so we conservatively + // reject RSTs that the spec would accept (TIME-WAIT with in-window + // SEQ, SynSent from same direction as own SYN, etc.). 
+ t.handleRst(key, conn, currentState, packetDir) return } - var newState TCPState - switch currentState { - case TCPStateNew: - if flags&TCPSyn != 0 && flags&TCPAck == 0 { - if conn.Direction == nftypes.Egress { - newState = TCPStateSynSent - } else { - newState = TCPStateSynReceived - } - } + newState := nextState(currentState, conn.Direction, packetDir, flags) + if newState == 0 || !conn.CompareAndSwapState(currentState, newState) { + return + } + t.onTransition(key, conn, currentState, newState, packetDir) +} - case TCPStateSynSent: - if flags&TCPSyn != 0 && flags&TCPAck != 0 { - if packetDir != conn.Direction { - newState = TCPStateEstablished - } else { - // Simultaneous open - newState = TCPStateSynReceived - } - } +// handleRst processes a RST segment. Late RSTs in TimeWait and spoofed RSTs +// from the SYN direction are ignored; otherwise the flow is tombstoned. +func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, currentState TCPState, packetDir nftypes.Direction) { + // TimeWait exists to absorb late segments; don't let a late RST + // tombstone the entry and break same-4-tuple reuse. + if currentState == TCPStateTimeWait { + return + } + // A RST from the same direction as the SYN cannot be a legitimate + // response and must not tear down a half-open connection. + if currentState == TCPStateSynSent && packetDir == conn.Direction { + return + } + if !conn.CompareAndSwapState(currentState, TCPStateClosed) { + return + } + conn.SetTombstone() + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace6("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } + t.sendEvent(nftypes.TypeEnd, conn, nil) +} - case TCPStateSynReceived: - if flags&TCPAck != 0 && flags&TCPSyn == 0 { - if packetDir == conn.Direction { - newState = TCPStateEstablished - } - } +// stateTransition describes one state's transition logic. 
It receives the +// packet's flags plus whether the packet direction matches the connection's +// origin direction (same=true means same side as the SYN initiator). Return 0 +// for no transition. +type stateTransition func(flags uint8, connDir nftypes.Direction, same bool) TCPState - case TCPStateEstablished: - if flags&TCPFin != 0 { - if packetDir == conn.Direction { - newState = TCPStateFinWait1 - } else { - newState = TCPStateCloseWait - } - } +// stateTable maps each state to its transition function. Centralized here so +// nextState stays trivial and each rule is easy to read in isolation. +var stateTable = map[TCPState]stateTransition{ + TCPStateNew: transNew, + TCPStateSynSent: transSynSent, + TCPStateSynReceived: transSynReceived, + TCPStateEstablished: transEstablished, + TCPStateFinWait1: transFinWait1, + TCPStateFinWait2: transFinWait2, + TCPStateClosing: transClosing, + TCPStateCloseWait: transCloseWait, + TCPStateLastAck: transLastAck, +} - case TCPStateFinWait1: - if packetDir != conn.Direction { - switch { - case flags&TCPFin != 0 && flags&TCPAck != 0: - newState = TCPStateClosing - case flags&TCPFin != 0: - newState = TCPStateClosing - case flags&TCPAck != 0: - newState = TCPStateFinWait2 - } - } +// nextState returns the target TCP state for the given current state and +// packet, or 0 if the packet does not trigger a transition. 
+func nextState(currentState TCPState, connDir, packetDir nftypes.Direction, flags uint8) TCPState { + fn, ok := stateTable[currentState] + if !ok { + return 0 + } + return fn(flags, connDir, packetDir == connDir) +} - case TCPStateFinWait2: - if flags&TCPFin != 0 { - newState = TCPStateTimeWait +func transNew(flags uint8, connDir nftypes.Direction, _ bool) TCPState { + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + if connDir == nftypes.Egress { + return TCPStateSynSent } + return TCPStateSynReceived + } + return 0 +} - case TCPStateClosing: - if flags&TCPAck != 0 { - newState = TCPStateTimeWait +func transSynSent(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if same { + return TCPStateSynReceived // simultaneous open } + return TCPStateEstablished + } + return 0 +} - case TCPStateCloseWait: - if flags&TCPFin != 0 { - newState = TCPStateLastAck - } +func transSynReceived(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && flags&TCPSyn == 0 && same { + return TCPStateEstablished + } + return 0 +} - case TCPStateLastAck: - if flags&TCPAck != 0 { - newState = TCPStateClosed - } +func transEstablished(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin == 0 { + return 0 + } + if same { + return TCPStateFinWait1 + } + return TCPStateCloseWait +} + +// transFinWait1 handles the active-close peer response. A FIN carrying our +// ACK piggybacked goes straight to TIME-WAIT (RFC 9293 §3.10.7.4, FIN-WAIT-1: +// "if our FIN has been ACKed... enter the TIME-WAIT state"); a lone FIN moves +// to CLOSING; a pure ACK of our FIN moves to FIN-WAIT-2. 
+func transFinWait1(flags uint8, _ nftypes.Direction, same bool) TCPState { + if same { + return 0 + } + if flags&TCPFin != 0 && flags&TCPAck != 0 { + return TCPStateTimeWait + } + switch { + case flags&TCPFin != 0: + return TCPStateClosing + case flags&TCPAck != 0: + return TCPStateFinWait2 + } + return 0 +} + +// transFinWait2 ignores own-side FIN retransmits; only the peer's FIN advances. +func transFinWait2(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin != 0 && !same { + return TCPStateTimeWait + } + return 0 +} + +// transClosing completes a simultaneous close on the peer's ACK. +func transClosing(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && !same { + return TCPStateTimeWait + } + return 0 +} + +// transCloseWait only advances to LastAck when WE send FIN, ignoring peer retransmits. +func transCloseWait(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPFin != 0 && same { + return TCPStateLastAck + } + return 0 +} + +// transLastAck closes the flow only on the peer's ACK (not our own ACK retransmits). +func transLastAck(flags uint8, _ nftypes.Direction, same bool) TCPState { + if flags&TCPAck != 0 && !same { + return TCPStateClosed + } + return 0 +} + +// onTransition handles logging and flow-event emission after a successful +// state transition. TimeWait and Closed are terminal for flow accounting. 
+func (t *TCPTracker) onTransition(key ConnKey, conn *TCPConnTrack, from, to TCPState, packetDir nftypes.Direction) { + traceOn := t.logger.Enabled(nblog.LevelTrace) + if traceOn { + t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, from, to, packetDir) } - if newState != 0 && conn.CompareAndSwapState(currentState, newState) { - t.logger.Trace4("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) - - switch newState { - case TCPStateTimeWait: + switch to { + case TCPStateTimeWait: + if traceOn { t.logger.Trace5("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) - - case TCPStateClosed: - conn.SetTombstone() + } + t.sendEvent(nftypes.TypeEnd, conn, nil) + case TCPStateClosed: + conn.SetTombstone() + if traceOn { t.logger.Trace5("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) - t.sendEvent(nftypes.TypeEnd, conn, nil) } + t.sendEvent(nftypes.TypeEnd, conn, nil) } } -// isValidStateForFlags checks if the TCP flags are valid for the current connection state +// isValidStateForFlags checks if the TCP flags are valid for the current +// connection state. Caller must have already verified the flag combination is +// legal via isValidFlagCombination. 
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { - if !isValidFlagCombination(flags) { - return false - } if flags&TCPRst != 0 { if state == TCPStateSynSent { return flags&TCPAck != 0 @@ -449,15 +628,24 @@ func (t *TCPTracker) cleanup() { timeout = t.waitTimeout case TCPStateEstablished: timeout = t.timeout + case TCPStateFinWait1, TCPStateFinWait2, TCPStateClosing: + timeout = t.finWaitTimeout + case TCPStateCloseWait: + timeout = t.closeWaitTimeout + case TCPStateLastAck: + timeout = t.lastAckTimeout default: + // SynSent / SynReceived / New timeout = TCPHandshakeTimeout } if conn.timeoutExceeded(timeout) { delete(t.connections, key) - t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", - key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace6("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", + key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } // event already handled by state change if currentState != TCPStateTimeWait { diff --git a/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go b/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go new file mode 100644 index 000000000..81d4f5710 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_rst_bugs_test.go @@ -0,0 +1,100 @@ +package conntrack + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/require" +) + +// RST hygiene tests: the tracker currently closes the flow on any RST that +// matches the 4-tuple, regardless of direction or state. These tests cover +// the minimum checks we want (no SEQ tracking). 
+ +func TestTCPRstInSynSentWrongDirection(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateSynSent, conn.GetState()) + + // A RST arriving in the same direction as the SYN (i.e. TrackOutbound) + // cannot be a legitimate response. It must not close the connection. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPRst|TCPAck, 0) + require.Equal(t, TCPStateSynSent, conn.GetState(), + "RST in same direction as SYN must not close connection") + require.False(t, conn.IsTombstone()) +} + +func TestTCPRstInTimeWaitIgnored(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + // Drive to TIME-WAIT via active close. + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + conn := tracker.connections[key] + require.Equal(t, TCPStateTimeWait, conn.GetState()) + require.False(t, conn.IsTombstone(), "TIME-WAIT must not be tombstoned") + + // Late RST during TIME-WAIT must not tombstone the entry (TIME-WAIT + // exists to absorb late segments). 
+ tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) + require.Equal(t, TCPStateTimeWait, conn.GetState(), + "RST in TIME-WAIT must not transition state") + require.False(t, conn.IsTombstone(), + "RST in TIME-WAIT must not tombstone the entry") +} + +func TestTCPIllegalFlagCombos(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + + // Illegal combos must be rejected and must not change state. + combos := []struct { + name string + flags uint8 + }{ + {"SYN+RST", TCPSyn | TCPRst}, + {"FIN+RST", TCPFin | TCPRst}, + {"SYN+FIN", TCPSyn | TCPFin}, + {"SYN+FIN+RST", TCPSyn | TCPFin | TCPRst}, + } + + for _, c := range combos { + t.Run(c.name, func(t *testing.T) { + before := conn.GetState() + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, c.flags, 0) + require.False(t, valid, "illegal flag combo must be rejected: %s", c.name) + require.Equal(t, before, conn.GetState(), + "illegal flag combo must not change state") + require.False(t, conn.IsTombstone()) + }) + } +} diff --git a/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go b/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go new file mode 100644 index 000000000..32112cd58 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_state_bugs_test.go @@ -0,0 +1,235 @@ +package conntrack + +import ( + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// These tests exercise cases where the TCP state machine currently advances +// on retransmitted or wrong-direction segments and tears the flow down +// prematurely. 
The direction checks in tcp.go's state-transition functions address this; these tests guard against regressions. + +func TestTCPCloseWaitRetransmittedPeerFIN(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Peer sends FIN -> CloseWait (our app has not yet closed). + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + conn := tracker.connections[key] + require.Equal(t, TCPStateCloseWait, conn.GetState()) + + // Peer retransmits their FIN (ACK may have been delayed). We have NOT + // sent our FIN yet, so state must remain CloseWait. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid, "retransmitted peer FIN must still be accepted") + require.Equal(t, TCPStateCloseWait, conn.GetState(), + "retransmitted peer FIN must not advance CloseWait to LastAck") + + // Our app finally closes -> LastAck. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Peer ACK closes. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateClosed, conn.GetState()) +} + +func TestTCPFinWait2RetransmittedOwnFIN(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // We initiate close. 
+ tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait2, conn.GetState()) + + // Stray retransmit of our own FIN (same direction as originator) must + // NOT advance FinWait2 to TimeWait; only the peer's FIN should. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateFinWait2, conn.GetState(), + "own FIN retransmit must not advance FinWait2 to TimeWait") + + // Peer FIN -> TimeWait. + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateTimeWait, conn.GetState()) +} + +func TestTCPLastAckDirectionCheck(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Drive to LastAck: peer FIN -> CloseWait, our FIN -> LastAck. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Our own ACK retransmit (same direction as originator) must NOT close. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState(), + "own ACK retransmit in LastAck must not transition to Closed") + + // Peer's ACK -> Closed. 
+ require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) + require.Equal(t, TCPStateClosed, conn.GetState()) +} + +func TestTCPFinWait1OwnAckDoesNotAdvance(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Our own ACK retransmit (same direction as originator) must not advance. + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.Equal(t, TCPStateFinWait1, conn.GetState(), + "own ACK in FinWait1 must not advance to FinWait2") +} + +func TestTCPPerStateTeardownTimeouts(t *testing.T) { + // Verify cleanup reaps entries in each teardown state at the configured + // per-state timeout, not at the single handshake timeout. + t.Setenv(EnvTCPFinWaitTimeout, "50ms") + t.Setenv(EnvTCPCloseWaitTimeout, "80ms") + t.Setenv(EnvTCPLastAckTimeout, "30ms") + + dstIP := netip.MustParseAddr("100.64.0.2") + dstPort := uint16(80) + + // Drives a connection to the target state, forces its lastSeen well + // beyond the configured timeout, runs cleanup, and asserts reaping. + cases := []struct { + name string + // drive takes a fresh tracker and returns the conn key after + // transitioning the flow into the intended teardown state. 
+ drive func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) + }{ + { + name: "FinWait1", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → FinWait1 + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait1 + }, + }, + { + name: "FinWait2", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // FinWait1 + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)) // → FinWait2 + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateFinWait2 + }, + }, + { + name: "CloseWait", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // → CloseWait + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateCloseWait + }, + }, + { + name: "LastAck", + drive: func(t *testing.T, tr *TCPTracker, srcIP netip.Addr, srcPort uint16) (ConnKey, TCPState) { + establishConnection(t, tr, srcIP, dstIP, srcPort, dstPort) + require.True(t, tr.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) // CloseWait + tr.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // → LastAck + return ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}, TCPStateLastAck + }, + }, + } + + // Use a unique source port per subtest so nothing aliases. 
+ port := uint16(12345) + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + require.Equal(t, 50*time.Millisecond, tracker.finWaitTimeout) + require.Equal(t, 80*time.Millisecond, tracker.closeWaitTimeout) + require.Equal(t, 30*time.Millisecond, tracker.lastAckTimeout) + + srcIP := netip.MustParseAddr("100.64.0.1") + port++ + key, wantState := c.drive(t, tracker, srcIP, port) + conn := tracker.connections[key] + require.NotNil(t, conn) + require.Equal(t, wantState, conn.GetState()) + + // Age the entry past the largest per-state timeout. + conn.lastSeen.Store(time.Now().Add(-500 * time.Millisecond).UnixNano()) + tracker.cleanup() + _, exists := tracker.connections[key] + require.False(t, exists, "%s entry should be reaped", c.name) + }) + } +} + +func TestTCPEstablishedPSHACKInFinStates(t *testing.T) { + // Verifies FIN|PSH|ACK and bare ACK keepalives are not dropped in FIN + // teardown states, which some stacks emit during close. + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Peer FIN -> CloseWait. + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)) + + // Peer pushes trailing data + FIN|PSH|ACK (legal). + require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPPush|TCPAck, 100), + "FIN|PSH|ACK in CloseWait must be accepted") + + // Bare ACK keepalive from peer in CloseWait must be accepted. 
+ require.True(t, tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0), + "bare ACK in CloseWait must be accepted") +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a3b6a418b..335c5832a 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -17,6 +17,9 @@ const ( DefaultUDPTimeout = 30 * time.Second // UDPCleanupInterval is how often we check for stale connections UDPCleanupInterval = 15 * time.Second + + // EnvUDPMaxEntries caps the UDP conntrack table size. + EnvUDPMaxEntries = "NB_CONNTRACK_UDP_MAX" ) // UDPConnTrack represents a UDP connection state @@ -34,6 +37,7 @@ type UDPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc mutex sync.RWMutex + maxEntries int flowLogger nftypes.FlowLogger } @@ -51,6 +55,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), tickerCancel: cancel, + maxEntries: envInt(logger, EnvUDPMaxEntries, DefaultMaxUDPEntries), flowLogger: flowLogger, } @@ -117,13 +122,18 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d conn.UpdateCounters(direction, size) t.mutex.Lock() + if t.maxEntries > 0 && len(t.connections) >= t.maxEntries { + t.evictOneLocked() + } t.connections[key] = conn t.mutex.Unlock() - if origPort != 0 { - t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) - } else { - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if t.logger.Enabled(nblog.LevelTrace) { + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } } t.sendEvent(nftypes.TypeStart, conn, ruleID) } @@ -151,6 +161,34 @@ func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP 
netip.Addr, srcPort return true } +// evictOneLocked removes one entry to make room. Caller must hold t.mutex. +// Bounded sample: picks the oldest among up to evictSampleSize entries. +func (t *UDPTracker) evictOneLocked() { + var candKey ConnKey + var candSeen int64 + haveCand := false + sampled := 0 + + for k, c := range t.connections { + seen := c.lastSeen.Load() + if !haveCand || seen < candSeen { + candKey = k + candSeen = seen + haveCand = true + } + sampled++ + if sampled >= evictSampleSize { + break + } + } + if haveCand { + if evicted := t.connections[candKey]; evicted != nil { + t.sendEvent(nftypes.TypeEnd, evicted, nil) + } + delete(t.connections, candKey) + } +} + // cleanupRoutine periodically removes stale connections func (t *UDPTracker) cleanupRoutine(ctx context.Context) { defer t.cleanupTicker.Stop() @@ -173,8 +211,10 @@ func (t *UDPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", - key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + if t.logger.Enabled(nblog.LevelTrace) { + t.logger.Trace5("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + } t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 24b3d0167..1d4dcb1e5 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -709,7 +709,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { srcIP, dstIP := m.extractIPs(d) if !srcIP.IsValid() { - m.logger.Error1("Unknown network layer: %v", d.decoded[0]) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("Unknown network layer: %v", d.decoded[0]) + } return false } @@ -808,7 +810,9 @@ func (m *Manager) 
clampTCPMSS(packetData []byte, d *decoder) bool { return false } - m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + } return true } @@ -931,8 +935,10 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { // TODO: pass fragments of routed packets to forwarder if fragment { - m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", - srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace4("packet is a fragment: src=%v dst=%v id=%v flags=%v", + srcIP, dstIP, d.ip4.Id, d.ip4.Flags) + } return false } @@ -974,8 +980,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet pnum := getProtocolFromPacket(d) srcPort, dstPort := getPortsFromPacket(d) - m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + ruleID, pnum, srcIP, srcPort, dstIP, dstPort) + } m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -1025,8 +1033,10 @@ func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { // Drop if routing is disabled if !m.routingEnabled.Load() { - m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", - srcIP, dstIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Dropping routed packet (routing disabled): src=%s dst=%s", + srcIP, dstIP) + } return true } @@ -1043,8 +1053,10 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe if !pass { proto := getProtocolFromPacket(d) - m.logger.Trace6("Dropping 
routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", - ruleID, proto, srcIP, srcPort, dstIP, dstPort) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", + ruleID, proto, srcIP, srcPort, dstIP, dstPort) + } m.flowLogger.StoreEvent(nftypes.EventFields{ FlowID: uuid.New(), @@ -1126,7 +1138,9 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { // It returns true, true if the packet is a fragment and valid. func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) { if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Trace1("couldn't decode packet, err: %s", err) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace1("couldn't decode packet, err: %s", err) + } return false, false } diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index cb3db325d..5aa280d43 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -13,6 +13,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" ) @@ -92,8 +93,10 @@ func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []by return nil, fmt.Errorf("write ICMP packet: %w", err) } - f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", - epID(id), icmpType, icmpCode) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v", + epID(id), icmpType, icmpCode) + } return conn, nil } @@ -116,8 +119,10 @@ func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndp txBytes := f.handleEchoResponse(conn, id) rtt := time.Since(sendTime).Round(10 * time.Microsecond) - f.logger.Trace4("forwarder: 
Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", - epID(id), icmpType, icmpCode, rtt) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)", + epID(id), icmpType, icmpCode, rtt) + } f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } @@ -198,13 +203,17 @@ func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpoi } rtt := time.Since(pingStart).Round(10 * time.Microsecond) - f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", - epID(id), icmpType, icmpCode) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v", + epID(id), icmpType, icmpCode) + } txBytes := f.synthesizeEchoReply(id, icmpData) - f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", - epID(id), icmpType, icmpCode, rtt) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)", + epID(id), icmpType, icmpCode, rtt) + } f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes)) } diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go index aef420061..8e95522fe 100644 --- a/client/firewall/uspfilter/forwarder/tcp.go +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -1,12 +1,9 @@ package forwarder import ( - "context" "fmt" - "io" "net" "net/netip" - "sync" "github.com/google/uuid" @@ -16,7 +13,9 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" + "github.com/netbirdio/netbird/util/netrelay" ) // handleTCP is called by the TCP forwarder for new connections. 
@@ -38,7 +37,9 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) if err != nil { r.Complete(true) - f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace2("forwarder: dial error for %v: %v", epID(id), err) + } return } @@ -61,64 +62,22 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { inConn := gonet.NewTCPConn(&wq, ep) success = true - f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: established TCP connection %v", epID(id)) + } go f.proxyTCP(id, inConn, outConn, ep, flowID) } func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) { + // netrelay.Relay copies bidirectionally with proper half-close propagation + // and fully closes both conns before returning. + bytesFromInToOut, bytesFromOutToIn := netrelay.Relay(f.ctx, inConn, outConn, netrelay.Options{ + Logger: f.logger, + }) - ctx, cancel := context.WithCancel(f.ctx) - defer cancel() - - go func() { - <-ctx.Done() - // Close connections and endpoint. 
- if err := inConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug1("forwarder: inConn close error: %v", err) - } - if err := outConn.Close(); err != nil && !isClosedError(err) { - f.logger.Debug1("forwarder: outConn close error: %v", err) - } - - ep.Close() - }() - - var wg sync.WaitGroup - wg.Add(2) - - var ( - bytesFromInToOut int64 // bytes from client to server (tx for client) - bytesFromOutToIn int64 // bytes from server to client (rx for client) - errInToOut error - errOutToIn error - ) - - go func() { - bytesFromInToOut, errInToOut = io.Copy(outConn, inConn) - cancel() - wg.Done() - }() - - go func() { - - bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn) - cancel() - wg.Done() - }() - - wg.Wait() - - if errInToOut != nil { - if !isClosedError(errInToOut) { - f.logger.Error2("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut) - } - } - if errOutToIn != nil { - if !isClosedError(errOutToIn) { - f.logger.Error2("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn) - } - } + // Close the netstack endpoint after both conns are drained. 
+ ep.Close() var rxPackets, txPackets uint64 if tcpStats, ok := ep.Stats().(*tcp.Stats); ok { @@ -127,7 +86,9 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn txPackets = tcpStats.SegmentsReceived.Value() } - f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace5("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut) + } f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets) } diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index f175e275b..778a40842 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -125,7 +125,9 @@ func (f *udpForwarder) cleanup() { delete(f.conns, idle.id) f.Unlock() - f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: cleaned up idle UDP connection %v", epID(idle.id)) + } } } } @@ -144,7 +146,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { _, exists := f.udpForwarder.conns[id] f.udpForwarder.RUnlock() if exists { - f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id)) + } return true } @@ -206,7 +210,9 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool { f.udpForwarder.Unlock() success = true - f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace1("forwarder: established UDP connection %v", epID(id)) + } go f.proxyUDP(connCtx, pConn, id, ep) return true 
@@ -265,7 +271,9 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack txPackets = udpStats.PacketsReceived.Value() } - f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + if f.logger.Enabled(nblog.LevelTrace) { + f.logger.Trace5("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes) + } f.udpForwarder.Lock() delete(f.udpForwarder.conns, id) diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index c6ca55e70..03e7d4809 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -53,16 +53,17 @@ var levelStrings = map[Level]string{ } type logMessage struct { - level Level - format string - arg1 any - arg2 any - arg3 any - arg4 any - arg5 any - arg6 any - arg7 any - arg8 any + level Level + argCount uint8 + format string + arg1 any + arg2 any + arg3 any + arg4 any + arg5 any + arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -107,6 +108,13 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } +// Enabled reports whether the given level is currently logged. Callers on the +// hot path should guard log sites with this to avoid boxing arguments into +// any when the level is off. 
+func (l *Logger) Enabled(level Level) bool { + return l.level.Load() >= uint32(level) +} + func (l *Logger) Error(format string) { if l.level.Load() >= uint32(LevelError) { select { @@ -155,7 +163,7 @@ func (l *Logger) Trace(format string) { func (l *Logger) Error1(format string, arg1 any) { if l.level.Load() >= uint32(LevelError) { select { - case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelError, argCount: 1, format: format, arg1: arg1}: default: } } @@ -164,7 +172,16 @@ func (l *Logger) Error1(format string, arg1 any) { func (l *Logger) Error2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelError) { select { - case l.msgChannel <- logMessage{level: LevelError, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelError, argCount: 2, format: format, arg1: arg1, arg2: arg2}: + default: + } + } +} + +func (l *Logger) Warn2(format string, arg1, arg2 any) { + if l.level.Load() >= uint32(LevelWarn) { + select { + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -173,7 +190,7 @@ func (l *Logger) Error2(format string, arg1, arg2 any) { func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelWarn) { select { - case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } @@ -182,7 +199,7 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) { func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { if l.level.Load() >= uint32(LevelWarn) { select { - case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + case l.msgChannel <- logMessage{level: LevelWarn, argCount: 4, format: format, arg1: 
arg1, arg2: arg2, arg3: arg3, arg4: arg4}: default: } } @@ -191,7 +208,7 @@ func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Debug1(format string, arg1 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 1, format: format, arg1: arg1}: default: } } @@ -200,7 +217,7 @@ func (l *Logger) Debug1(format string, arg1 any) { func (l *Logger) Debug2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -209,16 +226,59 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelDebug) { select { - case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelDebug, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } } +// Debugf is the variadic shape. Dispatches to Debug/Debug1/Debug2/Debug3 +// to avoid allocating an args slice on the fast path when the arg count is +// known (0-3). Args beyond 3 land on the general variadic path; callers on +// the hot path should prefer DebugN for known counts. 
+func (l *Logger) Debugf(format string, args ...any) { + if l.level.Load() < uint32(LevelDebug) { + return + } + switch len(args) { + case 0: + l.Debug(format) + case 1: + l.Debug1(format, args[0]) + case 2: + l.Debug2(format, args[0], args[1]) + case 3: + l.Debug3(format, args[0], args[1], args[2]) + default: + l.sendVariadic(LevelDebug, format, args) + } +} + +// sendVariadic packs a slice of arguments into a logMessage and non-blocking +// enqueues it. Used for arg counts beyond the fixed-arity fast paths. Args +// beyond the 8-arg slot limit are dropped so callers don't produce silently +// empty log lines via uint8 wraparound in argCount. +func (l *Logger) sendVariadic(level Level, format string, args []any) { + const maxArgs = 8 + n := len(args) + if n > maxArgs { + n = maxArgs + } + msg := logMessage{level: level, argCount: uint8(n), format: format} + slots := [maxArgs]*any{&msg.arg1, &msg.arg2, &msg.arg3, &msg.arg4, &msg.arg5, &msg.arg6, &msg.arg7, &msg.arg8} + for i := 0; i < n; i++ { + *slots[i] = args[i] + } + select { + case l.msgChannel <- msg: + default: + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 1, format: format, arg1: arg1}: default: } } @@ -227,7 +287,7 @@ func (l *Logger) Trace1(format string, arg1 any) { func (l *Logger) Trace2(format string, arg1, arg2 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 2, format: format, arg1: arg1, arg2: arg2}: default: } } @@ -236,7 +296,7 @@ func (l *Logger) Trace2(format string, arg1, arg2 any) { func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- 
logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 3, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: default: } } @@ -245,7 +305,7 @@ func (l *Logger) Trace3(format string, arg1, arg2, arg3 any) { func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 4, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}: default: } } @@ -254,7 +314,7 @@ func (l *Logger) Trace4(format string, arg1, arg2, arg3, arg4 any) { func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 5, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5}: default: } } @@ -263,7 +323,7 @@ func (l *Logger) Trace5(format string, arg1, arg2, arg3, arg4, arg5 any) { func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 6, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6}: default: } } @@ -273,7 +333,7 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { if l.level.Load() >= uint32(LevelTrace) { select { - case l.msgChannel <- 
logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + case l.msgChannel <- logMessage{level: LevelTrace, argCount: 8, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: default: } } @@ -286,35 +346,8 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = append(*buf, levelStrings[msg.level]...) *buf = append(*buf, ' ') - // Count non-nil arguments for switch - argCount := 0 - if msg.arg1 != nil { - argCount++ - if msg.arg2 != nil { - argCount++ - if msg.arg3 != nil { - argCount++ - if msg.arg4 != nil { - argCount++ - if msg.arg5 != nil { - argCount++ - if msg.arg6 != nil { - argCount++ - if msg.arg7 != nil { - argCount++ - if msg.arg8 != nil { - argCount++ - } - } - } - } - } - } - } - } - var formatted string - switch argCount { + switch msg.argCount { case 0: formatted = msg.format case 1: diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 8ed32eb5e..c24b18daa 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -11,6 +11,7 @@ import ( "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") @@ -242,11 +243,15 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { } if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { - m.logger.Error1("failed to rewrite packet destination: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite packet destination: %v", err) + } return false } - m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("DNAT: %s -> %s", dstIP, translatedIP) + } return true } @@ 
-264,11 +269,15 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { } if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { - m.logger.Error1("failed to rewrite packet source: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite packet source: %v", err) + } return false } - m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) + if m.logger.Enabled(nblog.LevelTrace) { + m.logger.Trace2("Reverse DNAT: %s -> %s", srcIP, originalIP) + } return true } @@ -521,7 +530,9 @@ func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP neti } if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { - m.logger.Error1("failed to rewrite port: %v", err) + if m.logger.Enabled(nblog.LevelError) { + m.logger.Error1("failed to rewrite port: %v", err) + } return false } d.dnatOrigPort = rule.origPort diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 7f72a72cf..61904366d 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -25,6 +25,7 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/util" + "github.com/netbirdio/netbird/util/netrelay" ) const ( @@ -536,7 +537,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str continue } - go c.handleLocalForward(localConn, remoteAddr) + go c.handleLocalForward(ctx, localConn, remoteAddr) } }() @@ -548,7 +549,7 @@ func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr str } // handleLocalForward handles a single local port forwarding connection -func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { +func (c *Client) handleLocalForward(ctx context.Context, localConn net.Conn, remoteAddr string) { defer func() { if err := localConn.Close(); err != nil { log.Debugf("local port 
forwarding: close local connection: %v", err) @@ -571,7 +572,7 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { } }() - nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) + netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } // RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr @@ -653,16 +654,19 @@ func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr stri select { case <-ctx.Done(): return - case newChan := <-channelRequests: + case newChan, ok := <-channelRequests: + if !ok { + return + } if newChan != nil { - go c.handleRemoteForwardChannel(newChan, localAddr) + go c.handleRemoteForwardChannel(ctx, newChan, localAddr) } } } } // handleRemoteForwardChannel handles a single forwarded-tcpip channel -func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) { +func (c *Client) handleRemoteForwardChannel(ctx context.Context, newChan ssh.NewChannel, localAddr string) { channel, reqs, err := newChan.Accept() if err != nil { return @@ -675,7 +679,8 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st go ssh.DiscardRequests(reqs) - localConn, err := net.Dial("tcp", localAddr) + var dialer net.Dialer + localConn, err := dialer.DialContext(ctx, "tcp", localAddr) if err != nil { return } @@ -685,7 +690,7 @@ func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr st } }() - nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel) + netrelay.Relay(ctx, localConn, channel, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } // tcpipForwardMsg represents the structure for tcpip-forward requests diff --git a/client/ssh/common.go b/client/ssh/common.go index f6aec5f9c..92e647b7d 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -194,63 +194,3 @@ func buildAddressList(hostname 
string, remote net.Addr) []string { return addresses } -// BidirectionalCopy copies data bidirectionally between two io.ReadWriter connections. -// It waits for both directions to complete before returning. -// The caller is responsible for closing the connections. -func BidirectionalCopy(logger *log.Entry, rw1, rw2 io.ReadWriter) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(rw2, rw1); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (1->2): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(rw1, rw2); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (2->1): %v", err) - } - done <- struct{}{} - }() - - <-done - <-done -} - -func isExpectedCopyError(err error) bool { - return errors.Is(err, io.EOF) || errors.Is(err, context.Canceled) -} - -// BidirectionalCopyWithContext copies data bidirectionally between two io.ReadWriteCloser connections. -// It waits for both directions to complete or for context cancellation before returning. -// Both connections are closed when the function returns. 
-func BidirectionalCopyWithContext(logger *log.Entry, ctx context.Context, conn1, conn2 io.ReadWriteCloser) { - done := make(chan struct{}, 2) - - go func() { - if _, err := io.Copy(conn2, conn1); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (1->2): %v", err) - } - done <- struct{}{} - }() - - go func() { - if _, err := io.Copy(conn1, conn2); err != nil && !isExpectedCopyError(err) { - logger.Debugf("copy error (2->1): %v", err) - } - done <- struct{}{} - }() - - select { - case <-ctx.Done(): - case <-done: - select { - case <-ctx.Done(): - case <-done: - } - } - - _ = conn1.Close() - _ = conn2.Close() -} diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index 59007f75c..f6bc0d250 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -23,6 +23,7 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util/netrelay" "github.com/netbirdio/netbird/version" ) @@ -352,7 +353,7 @@ func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, ne } go cryptossh.DiscardRequests(clientReqs) - nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) + netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) { @@ -591,7 +592,7 @@ func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh } go cryptossh.DiscardRequests(clientReqs) - nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan) + netrelay.Relay(sshCtx, clientChan, backendChan, netrelay.Options{Logger: log.NewEntry(log.StandardLogger())}) } func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) { diff --git 
a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index e16ff5d46..a819840a5 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -17,7 +17,7 @@ import ( log "github.com/sirupsen/logrus" cryptossh "golang.org/x/crypto/ssh" - nbssh "github.com/netbirdio/netbird/client/ssh" + "github.com/netbirdio/netbird/util/netrelay" ) const privilegedPortThreshold = 1024 @@ -356,7 +356,7 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h return } - nbssh.BidirectionalCopyWithContext(logger, ctx, conn, channel) + netrelay.Relay(ctx, conn, channel, netrelay.Options{Logger: logger}) } // openForwardChannel creates an SSH forwarded-tcpip channel diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 82d3b700f..739eec513 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -10,6 +10,7 @@ import ( "net" "net/netip" "slices" + "strconv" "strings" "sync" "time" @@ -26,6 +27,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/auth/jwt" + "github.com/netbirdio/netbird/util/netrelay" "github.com/netbirdio/netbird/version" ) @@ -52,6 +54,10 @@ const ( DefaultJWTMaxTokenAge = 10 * 60 ) +// directTCPIPDialTimeout bounds how long relayDirectTCPIP waits on a dial to +// the forwarded destination before rejecting the SSH channel. 
+const directTCPIPDialTimeout = 30 * time.Second + var ( ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled) ErrUserNotFound = errors.New("user not found") @@ -891,5 +897,29 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.addConnectionPortForward(ctx.User(), ctx.RemoteAddr(), forwardAddr) logger.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) - ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) + s.relayDirectTCPIP(ctx, newChan, payload.Host, int(payload.Port), logger) +} + +// relayDirectTCPIP is a netrelay-based replacement for gliderlabs' +// DirectTCPIPHandler. The upstream handler closes both sides on the first +// EOF; netrelay.Relay propagates CloseWrite so each direction drains on its +// own terms. +func (s *Server) relayDirectTCPIP(ctx ssh.Context, newChan cryptossh.NewChannel, host string, port int, logger *log.Entry) { + dest := net.JoinHostPort(host, strconv.Itoa(port)) + + dialer := net.Dialer{Timeout: directTCPIPDialTimeout} + dconn, err := dialer.DialContext(ctx, "tcp", dest) + if err != nil { + _ = newChan.Reject(cryptossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + _ = dconn.Close() + return + } + go cryptossh.DiscardRequests(reqs) + + netrelay.Relay(ctx, dconn, ch, netrelay.Options{Logger: logger}) } diff --git a/proxy/internal/tcp/peekedconn.go b/proxy/internal/tcp/peekedconn.go index 26f3e5c7c..23a348352 100644 --- a/proxy/internal/tcp/peekedconn.go +++ b/proxy/internal/tcp/peekedconn.go @@ -25,6 +25,12 @@ func (c *peekedConn) Read(b []byte) (int, error) { return c.reader.Read(b) } +// halfCloser matches connections that support shutting down the write +// side while keeping the read side open (e.g. *net.TCPConn). +type halfCloser interface { + CloseWrite() error +} + // CloseWrite delegates to the underlying connection if it supports // half-close (e.g. *net.TCPConn). 
Without this, embedding net.Conn // as an interface hides the concrete type's CloseWrite method, making diff --git a/proxy/internal/tcp/relay.go b/proxy/internal/tcp/relay.go deleted file mode 100644 index 39949818d..000000000 --- a/proxy/internal/tcp/relay.go +++ /dev/null @@ -1,156 +0,0 @@ -package tcp - -import ( - "context" - "errors" - "io" - "net" - "sync" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/netbirdio/netbird/proxy/internal/netutil" -) - -// errIdleTimeout is returned when a relay connection is closed due to inactivity. -var errIdleTimeout = errors.New("idle timeout") - -// DefaultIdleTimeout is the default idle timeout for TCP relay connections. -// A zero value disables idle timeout checking. -const DefaultIdleTimeout = 5 * time.Minute - -// halfCloser is implemented by connections that support half-close -// (e.g. *net.TCPConn). When one copy direction finishes, we signal -// EOF to the remote by closing the write side while keeping the read -// side open so the other direction can drain. -type halfCloser interface { - CloseWrite() error -} - -// copyBufPool avoids allocating a new 32KB buffer per io.Copy call. -var copyBufPool = sync.Pool{ - New: func() any { - buf := make([]byte, 32*1024) - return &buf - }, -} - -// Relay copies data bidirectionally between src and dst until both -// sides are done or the context is canceled. When idleTimeout is -// non-zero, each direction's read is deadline-guarded; if no data -// flows within the timeout the connection is torn down. When one -// direction finishes, it half-closes the write side of the -// destination (if supported) to signal EOF, allowing the other -// direction to drain gracefully before the full connection teardown. 
-func Relay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (srcToDst, dstToSrc int64) { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - go func() { - <-ctx.Done() - _ = src.Close() - _ = dst.Close() - }() - - var wg sync.WaitGroup - wg.Add(2) - - var errSrcToDst, errDstToSrc error - - go func() { - defer wg.Done() - srcToDst, errSrcToDst = copyWithIdleTimeout(dst, src, idleTimeout) - halfClose(dst) - cancel() - }() - - go func() { - defer wg.Done() - dstToSrc, errDstToSrc = copyWithIdleTimeout(src, dst, idleTimeout) - halfClose(src) - cancel() - }() - - wg.Wait() - - if errors.Is(errSrcToDst, errIdleTimeout) || errors.Is(errDstToSrc, errIdleTimeout) { - logger.Debug("relay closed due to idle timeout") - } - if errSrcToDst != nil && !isExpectedCopyError(errSrcToDst) { - logger.Debugf("relay copy error (src→dst): %v", errSrcToDst) - } - if errDstToSrc != nil && !isExpectedCopyError(errDstToSrc) { - logger.Debugf("relay copy error (dst→src): %v", errDstToSrc) - } - - return srcToDst, dstToSrc -} - -// copyWithIdleTimeout copies from src to dst using a pooled buffer. -// When idleTimeout > 0 it sets a read deadline on src before each -// read and treats a timeout as an idle-triggered close. 
-func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { - bufp := copyBufPool.Get().(*[]byte) - defer copyBufPool.Put(bufp) - - if idleTimeout <= 0 { - return io.CopyBuffer(dst, src, *bufp) - } - - conn, ok := src.(net.Conn) - if !ok { - return io.CopyBuffer(dst, src, *bufp) - } - - buf := *bufp - var total int64 - for { - if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { - return total, err - } - nr, readErr := src.Read(buf) - if nr > 0 { - n, err := checkedWrite(dst, buf[:nr]) - total += n - if err != nil { - return total, err - } - } - if readErr != nil { - if netutil.IsTimeout(readErr) { - return total, errIdleTimeout - } - return total, readErr - } - } -} - -// checkedWrite writes buf to dst and returns the number of bytes written. -// It guards against short writes and negative counts per io.Copy convention. -func checkedWrite(dst io.Writer, buf []byte) (int64, error) { - nw, err := dst.Write(buf) - if nw < 0 || nw > len(buf) { - nw = 0 - } - if err != nil { - return int64(nw), err - } - if nw != len(buf) { - return int64(nw), io.ErrShortWrite - } - return int64(nw), nil -} - -func isExpectedCopyError(err error) bool { - return errors.Is(err, errIdleTimeout) || netutil.IsExpectedError(err) -} - -// halfClose attempts to half-close the write side of the connection. -// If the connection does not support half-close, this is a no-op. -func halfClose(conn net.Conn) { - if hc, ok := conn.(halfCloser); ok { - // Best-effort; the full close will follow shortly. 
- _ = hc.CloseWrite() - } -} diff --git a/proxy/internal/tcp/relay_test.go b/proxy/internal/tcp/relay_test.go index e42d65b9d..f83a0d155 100644 --- a/proxy/internal/tcp/relay_test.go +++ b/proxy/internal/tcp/relay_test.go @@ -13,8 +13,13 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/proxy/internal/netutil" + "github.com/netbirdio/netbird/util/netrelay" ) +func testRelay(ctx context.Context, logger *log.Entry, src, dst net.Conn, idleTimeout time.Duration) (int64, int64) { + return netrelay.Relay(ctx, src, dst, netrelay.Options{IdleTimeout: idleTimeout, Logger: logger}) +} + func TestRelay_BidirectionalCopy(t *testing.T) { srcClient, srcServer := net.Pipe() dstClient, dstServer := net.Pipe() @@ -41,7 +46,7 @@ func TestRelay_BidirectionalCopy(t *testing.T) { srcClient.Close() }() - s2d, d2s := Relay(ctx, logger, srcServer, dstServer, 0) + s2d, d2s := testRelay(ctx, logger, srcServer, dstServer, 0) assert.Equal(t, int64(len(srcData)), s2d, "bytes src→dst") assert.Equal(t, int64(len(dstData)), d2s, "bytes dst→src") @@ -58,7 +63,7 @@ func TestRelay_ContextCancellation(t *testing.T) { done := make(chan struct{}) go func() { - Relay(ctx, logger, srcServer, dstServer, 0) + testRelay(ctx, logger, srcServer, dstServer, 0) close(done) }() @@ -85,7 +90,7 @@ func TestRelay_OneSideClosed(t *testing.T) { done := make(chan struct{}) go func() { - Relay(ctx, logger, srcServer, dstServer, 0) + testRelay(ctx, logger, srcServer, dstServer, 0) close(done) }() @@ -129,7 +134,7 @@ func TestRelay_LargeTransfer(t *testing.T) { dstClient.Close() }() - s2d, _ := Relay(ctx, logger, srcServer, dstServer, 0) + s2d, _ := testRelay(ctx, logger, srcServer, dstServer, 0) assert.Equal(t, int64(len(data)), s2d, "should transfer all bytes") require.NoError(t, <-errCh) } @@ -182,7 +187,7 @@ func TestRelay_IdleTimeout(t *testing.T) { done := make(chan struct{}) var s2d, d2s int64 go func() { - s2d, d2s = Relay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) + 
s2d, d2s = testRelay(ctx, logger, srcServer, dstServer, 200*time.Millisecond) close(done) }() diff --git a/proxy/internal/tcp/router.go b/proxy/internal/tcp/router.go index 9f8660aeb..05beb658b 100644 --- a/proxy/internal/tcp/router.go +++ b/proxy/internal/tcp/router.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/restrict" "github.com/netbirdio/netbird/proxy/internal/types" + "github.com/netbirdio/netbird/util/netrelay" ) // defaultDialTimeout is the fallback dial timeout when no per-route @@ -528,11 +529,14 @@ func (r *Router) relayTCP(ctx context.Context, conn net.Conn, sni SNIHost, route idleTimeout := route.SessionIdleTimeout if idleTimeout <= 0 { - idleTimeout = DefaultIdleTimeout + idleTimeout = netrelay.DefaultIdleTimeout } start := time.Now() - s2d, d2s := Relay(svcCtx, entry, conn, backend, idleTimeout) + s2d, d2s := netrelay.Relay(svcCtx, conn, backend, netrelay.Options{ + IdleTimeout: idleTimeout, + Logger: entry, + }) elapsed := time.Since(start) if obs != nil { diff --git a/util/netrelay/relay.go b/util/netrelay/relay.go new file mode 100644 index 000000000..3afd35b1b --- /dev/null +++ b/util/netrelay/relay.go @@ -0,0 +1,213 @@ +// Package netrelay provides a bidirectional byte-copy helper for TCP-like +// connections with correct half-close propagation. +// +// When one direction reads EOF, the write side of the opposite connection is +// half-closed (CloseWrite) so the peer sees FIN, then the second direction is +// allowed to drain to its own EOF before both connections are fully closed. +// This preserves TCP half-close semantics (e.g. shutdown(SHUT_WR)) that the +// naive "cancel-both-on-first-EOF" pattern breaks. +package netrelay + +import ( + "context" + "errors" + "io" + "net" + "sync" + "syscall" + "time" +) + +// DebugLogger is the minimal interface netrelay uses to surface teardown +// errors. 
Both *logrus.Entry and *nblog.Logger (via its Debugf method) +// satisfy it, so callers can pass whichever they already use without an +// adapter. Debugf is the only required method; callers with richer +// loggers just expose this one shape here. +type DebugLogger interface { + Debugf(format string, args ...any) +} + +// DefaultIdleTimeout is a reasonable default for Options.IdleTimeout. Callers +// that want an idle timeout but have no specific preference can use this. +const DefaultIdleTimeout = 5 * time.Minute + +// ErrIdleTimeout is returned when a relay direction is torn down after no +// data flowed for longer than Options.IdleTimeout. +var ErrIdleTimeout = errors.New("idle timeout") + +// halfCloser is implemented by connections that support half-close +// (e.g. *net.TCPConn, *gonet.TCPConn). +type halfCloser interface { + CloseWrite() error +} + +var copyBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 32*1024) + return &buf + }, +} + +// Options configures Relay behavior. The zero value is valid: no idle timeout, +// no logging. +type Options struct { + // IdleTimeout tears down a direction if no bytes flow within this + // window. Zero disables idle tracking. + IdleTimeout time.Duration + // Logger receives debug-level copy/idle errors. Nil suppresses logging. + // Any logger with Debug/Debugf methods is accepted (logrus.Entry, + // uspfilter's nblog.Logger, etc.). + Logger DebugLogger +} + +// Relay copies bytes in both directions between a and b until both directions +// EOF or ctx is canceled. On each direction's EOF it half-closes the +// opposite conn's write side (best effort) so the peer sees FIN while the +// other direction drains. Both conns are fully closed when Relay returns. +// +// a and b only need to implement io.ReadWriteCloser; connections that also +// implement CloseWrite (e.g. *net.TCPConn, ssh.Channel) get proper half-close +// propagation. 
Idle-timeout enforcement requires a net.Conn; for other types +// Options.IdleTimeout is ignored. +// +// Return values are byte counts: aToB (a.Read → b.Write) and bToA (b.Read → +// a.Write). Errors are logged via Options.Logger when set; they are not +// returned because a relay always terminates on some kind of EOF/cancel. +func Relay(ctx context.Context, a, b io.ReadWriteCloser, opts Options) (aToB, bToA int64) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + _ = a.Close() + _ = b.Close() + }() + + // Both sides must support CloseWrite to propagate half-close. If either + // doesn't, a direction's EOF can't be signaled to the peer and the other + // direction would block forever waiting for data; in that case we fall + // back to the cancel-both-on-first-EOF behavior. + _, aHC := a.(halfCloser) + _, bHC := b.(halfCloser) + halfCloseSupported := aHC && bHC + + var wg sync.WaitGroup + wg.Add(2) + + var errAToB, errBToA error + + go func() { + defer wg.Done() + aToB, errAToB = copyWithIdleTimeout(b, a, opts.IdleTimeout) + if halfCloseSupported { + halfClose(b) + } else { + cancel() + } + }() + + go func() { + defer wg.Done() + bToA, errBToA = copyWithIdleTimeout(a, b, opts.IdleTimeout) + if halfCloseSupported { + halfClose(a) + } else { + cancel() + } + }() + + wg.Wait() + + if opts.Logger != nil { + if errors.Is(errAToB, ErrIdleTimeout) || errors.Is(errBToA, ErrIdleTimeout) { + opts.Logger.Debugf("relay closed due to idle timeout") + } + if errAToB != nil && !isExpectedCopyError(errAToB) { + opts.Logger.Debugf("relay copy error (a→b): %v", errAToB) + } + if errBToA != nil && !isExpectedCopyError(errBToA) { + opts.Logger.Debugf("relay copy error (b→a): %v", errBToA) + } + } + + return aToB, bToA +} + +// copyWithIdleTimeout copies from src to dst using a pooled buffer. When +// idleTimeout > 0 and src is a net.Conn it sets a read deadline before each +// read; a timeout is reported as ErrIdleTimeout. 
+func copyWithIdleTimeout(dst io.Writer, src io.Reader, idleTimeout time.Duration) (int64, error) { + bufp := copyBufPool.Get().(*[]byte) + defer copyBufPool.Put(bufp) + + if idleTimeout <= 0 { + return io.CopyBuffer(dst, src, *bufp) + } + + conn, ok := src.(net.Conn) + if !ok { + return io.CopyBuffer(dst, src, *bufp) + } + + buf := *bufp + var total int64 + for { + if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { + return total, err + } + nr, readErr := src.Read(buf) + if nr > 0 { + n, werr := checkedWrite(dst, buf[:nr]) + total += n + if werr != nil { + return total, werr + } + } + if readErr != nil { + if isNetTimeout(readErr) { + return total, ErrIdleTimeout + } + return total, readErr + } + } +} + +func checkedWrite(dst io.Writer, buf []byte) (int64, error) { + nw, err := dst.Write(buf) + if nw < 0 || nw > len(buf) { + nw = 0 + } + if err != nil { + return int64(nw), err + } + if nw != len(buf) { + return int64(nw), io.ErrShortWrite + } + return int64(nw), nil +} + +func halfClose(conn io.ReadWriteCloser) { + if hc, ok := conn.(halfCloser); ok { + _ = hc.CloseWrite() + } +} + +func isNetTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} + +func isExpectedCopyError(err error) bool { + if errors.Is(err, ErrIdleTimeout) { + return true + } + return errors.Is(err, net.ErrClosed) || + errors.Is(err, context.Canceled) || + errors.Is(err, io.EOF) || + errors.Is(err, syscall.ECONNRESET) || + errors.Is(err, syscall.EPIPE) || + errors.Is(err, syscall.ECONNABORTED) +} diff --git a/util/netrelay/relay_test.go b/util/netrelay/relay_test.go new file mode 100644 index 000000000..26baebfbb --- /dev/null +++ b/util/netrelay/relay_test.go @@ -0,0 +1,203 @@ +package netrelay + +import ( + "io" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// tcpPair returns two connected loopback TCP conns. 
+func tcpPair(t *testing.T) (*net.TCPConn, *net.TCPConn) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer ln.Close() + + type result struct { + c *net.TCPConn + err error + } + ch := make(chan result, 1) + go func() { + c, err := ln.Accept() + if err != nil { + ch <- result{nil, err} + return + } + ch <- result{c.(*net.TCPConn), nil} + }() + + dial, err := net.Dial("tcp", ln.Addr().String()) + require.NoError(t, err) + + r := <-ch + require.NoError(t, r.err) + return dial.(*net.TCPConn), r.c +} + +// TestRelayHalfClose exercises the shutdown(SHUT_WR) scenario that the naive +// cancel-both-on-first-EOF pattern breaks. Client A shuts down its write +// side; B must still be able to write a full response and A must receive +// all of it before its read returns EOF. +func TestRelayHalfClose(t *testing.T) { + // Real peer pairs for each side of the relay. We relay between relayA + // and relayB. Peer A talks through relayA; peer B talks through relayB. + peerA, relayA := tcpPair(t) + relayB, peerB := tcpPair(t) + + defer peerA.Close() + defer peerB.Close() + + ctx := t.Context() + + done := make(chan struct{}) + go func() { + Relay(ctx, relayA, relayB, Options{}) + close(done) + }() + + // Peer A sends a request, then half-closes its write side. + req := []byte("request-payload") + _, err := peerA.Write(req) + require.NoError(t, err) + require.NoError(t, peerA.CloseWrite()) + + // Peer B reads the request to EOF (FIN must have propagated). + got, err := io.ReadAll(peerB) + require.NoError(t, err) + require.Equal(t, req, got) + + // Peer B writes its response; peer A must receive all of it even though + // peer A's write side is already closed. 
+ resp := make([]byte, 64*1024) + for i := range resp { + resp[i] = byte(i) + } + _, err = peerB.Write(resp) + require.NoError(t, err) + require.NoError(t, peerB.Close()) + + gotResp, err := io.ReadAll(peerA) + require.NoError(t, err) + require.Equal(t, resp, gotResp) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not return") + } +} + +// TestRelayFullDuplex verifies bidirectional copy in the simple case. +func TestRelayFullDuplex(t *testing.T) { + peerA, relayA := tcpPair(t) + relayB, peerB := tcpPair(t) + defer peerA.Close() + defer peerB.Close() + + ctx := t.Context() + + done := make(chan struct{}) + go func() { + Relay(ctx, relayA, relayB, Options{}) + close(done) + }() + + type result struct { + got []byte + err error + } + resA := make(chan result, 1) + resB := make(chan result, 1) + + msgAB := []byte("hello-from-a") + msgBA := []byte("hello-from-b") + + go func() { + if _, err := peerA.Write(msgAB); err != nil { + resA <- result{err: err} + return + } + buf := make([]byte, len(msgBA)) + _, err := io.ReadFull(peerA, buf) + resA <- result{got: buf, err: err} + _ = peerA.Close() + }() + + go func() { + if _, err := peerB.Write(msgBA); err != nil { + resB <- result{err: err} + return + } + buf := make([]byte, len(msgAB)) + _, err := io.ReadFull(peerB, buf) + resB <- result{got: buf, err: err} + _ = peerB.Close() + }() + + a, b := <-resA, <-resB + require.NoError(t, a.err) + require.Equal(t, msgBA, a.got) + require.NoError(t, b.err) + require.Equal(t, msgAB, b.got) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not return") + } +} + +// TestRelayNoHalfCloseFallback ensures Relay terminates when the underlying +// conns don't support CloseWrite (e.g. net.Pipe). Without the fallback to +// cancel-both-on-first-EOF, the second direction would block forever. 
+func TestRelayNoHalfCloseFallback(t *testing.T) { + a1, a2 := net.Pipe() + b1, b2 := net.Pipe() + defer a1.Close() + defer b1.Close() + + ctx := t.Context() + done := make(chan struct{}) + go func() { + Relay(ctx, a2, b2, Options{}) + close(done) + }() + + // Close peer A's side; a2's Read will return EOF. + require.NoError(t, a1.Close()) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not terminate when half-close is unsupported") + } +} + +// TestRelayIdleTimeout ensures the idle watchdog tears down a silent flow. +func TestRelayIdleTimeout(t *testing.T) { + peerA, relayA := tcpPair(t) + relayB, peerB := tcpPair(t) + defer peerA.Close() + defer peerB.Close() + + ctx := t.Context() + + start := time.Now() + done := make(chan struct{}) + go func() { + Relay(ctx, relayA, relayB, Options{IdleTimeout: 150 * time.Millisecond}) + close(done) + }() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("relay did not close on idle") + } + + require.WithinDuration(t, start.Add(150*time.Millisecond), time.Now(), 500*time.Millisecond) +}