diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index ae9926795..c8ea159da 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) } diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 8109fff41..2d42ea32e 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -23,11 +23,11 @@ const ( ) const ( - TCPSyn uint8 = 0x02 - TCPAck uint8 = 0x10 TCPFin uint8 = 0x01 + TCPSyn uint8 = 0x02 TCPRst uint8 = 0x04 TCPPush uint8 = 0x08 + TCPAck uint8 = 0x10 TCPUrg uint8 = 0x20 ) @@ -41,7 +41,7 @@ const ( ) // TCPState represents the state of a TCP connection -type TCPState int +type TCPState int32 func (s TCPState) String() string { switch s { @@ -89,22 +89,25 @@ const ( // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { BaseConnTrack - SourcePort uint16 - DestPort uint16 - State TCPState - established atomic.Bool - tombstone atomic.Bool - sync.RWMutex + SourcePort uint16 + DestPort uint16 + state atomic.Int32 + tombstone atomic.Bool } -// IsEstablished safely checks if connection is established -func (t *TCPConnTrack) IsEstablished() bool { - return t.established.Load() +// GetState safely retrieves the current state +func (t *TCPConnTrack) GetState() TCPState { + return TCPState(t.state.Load()) } -// SetEstablished safely sets the established state -func (t *TCPConnTrack) SetEstablished(state bool) { - t.established.Store(state) +// SetState safely updates the current state +func (t *TCPConnTrack) SetState(state TCPState) { + t.state.Store(int32(state)) +} + +// CompareAndSwapState atomically changes the state from old to new if current == old +func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool { + return t.state.CompareAndSwap(int32(old), int32(newState)) } // IsTombstone safely checks if the connection is marked for deletion @@ -125,13 +128,17 @@ type TCPTracker struct { cleanupTicker *time.Ticker tickerCancel context.CancelFunc timeout time.Duration + waitTimeout time.Duration flowLogger nftypes.FlowLogger } // NewTCPTracker creates a new TCP connection tracker func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker { + waitTimeout := TimeWaitTimeout if timeout == 0 { timeout = DefaultTCPTimeout + } else { + waitTimeout = timeout / 45 } ctx, cancel := context.WithCancel(context.Background()) @@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp cleanupTicker: time.NewTicker(TCPCleanupInterval), tickerCancel: cancel, timeout: timeout, + waitTimeout: waitTimeout, flowLogger: flowLogger, } @@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort t.mutex.RUnlock() if exists { - conn.Lock() - t.updateState(key, conn, flags, conn.Direction == nftypes.Egress) - conn.Unlock() - - conn.UpdateCounters(direction, size) - + t.updateState(key, conn, flags, direction, size) return key, true } @@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort } // TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) { +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { // if (inverted direction) conn is not tracked, track this direction t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) @@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) { +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) - if exists { + if exists || flags&TCPSyn == 0 { return } @@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d DestPort: dstPort, } - conn.established.Store(false) conn.tombstone.Store(false) + conn.state.Store(int32(TCPStateNew)) t.logger.Trace("New %s TCP connection: %s", direction, key) - t.updateState(key, conn, flags, direction == nftypes.Egress) - conn.UpdateCounters(direction, size) + t.updateState(key, conn, flags, direction, size) t.mutex.Lock() t.connections[key] = conn @@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d } // IsValidInbound checks if an inbound TCP packet matches a tracked connection -func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool { +func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool { key := ConnKey{ SrcIP: dstIP, DstIP: srcIP, @@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort conn, exists := t.connections[key] t.mutex.RUnlock() - if !exists { + if !exists || conn.IsTombstone() { return false } - // Handle RST flag specially - it always causes transition to closed - if flags&TCPRst != 0 { - return t.handleRst(key, conn, size) + currentState := conn.GetState() + if !t.isValidStateForFlags(currentState, flags) { + t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key) + // allow all flags for established for now + if currentState == TCPStateEstablished { + return true + } + return false } - conn.Lock() - t.updateState(key, conn, flags, false) - isEstablished := conn.IsEstablished() - isValidState := t.isValidStateForFlags(conn.State, flags) - conn.Unlock() - conn.UpdateCounters(nftypes.Ingress, size) - - return isEstablished || isValidState -} - -func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool { - if conn.IsTombstone() { - return true - } - - conn.Lock() - conn.SetTombstone() - conn.State = TCPStateClosed - conn.SetEstablished(false) - conn.Unlock() - conn.UpdateCounters(nftypes.Ingress, size) - - t.logger.Trace("TCP connection reset: %s", key) - t.sendEvent(nftypes.TypeEnd, conn, nil) + t.updateState(key, conn, flags, nftypes.Ingress, size) return true } // updateState updates the TCP connection state based on flags -func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) { +func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) { conn.UpdateLastSeen() + conn.UpdateCounters(packetDir, size) - state := conn.State - defer func() { - if state != conn.State { - t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State) + currentState := conn.GetState() + + if flags&TCPRst != 0 { + if conn.CompareAndSwapState(currentState, TCPStateClosed) { + conn.SetTombstone() + t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + t.sendEvent(nftypes.TypeEnd, conn, nil) } - }() + return + } - switch state { + var newState TCPState + switch currentState { case TCPStateNew: if flags&TCPSyn != 0 && flags&TCPAck == 0 { - conn.State = TCPStateSynSent + if conn.Direction == nftypes.Egress { + newState = TCPStateSynSent + } else { + newState = TCPStateSynReceived + } } case TCPStateSynSent: if flags&TCPSyn != 0 && flags&TCPAck != 0 { - if isOutbound { - conn.State = TCPStateEstablished - conn.SetEstablished(true) + if packetDir != conn.Direction { + newState = TCPStateEstablished } else { // Simultaneous open - conn.State = TCPStateSynReceived + newState = TCPStateSynReceived } } case TCPStateSynReceived: if flags&TCPAck != 0 && flags&TCPSyn == 0 { - conn.State = TCPStateEstablished - conn.SetEstablished(true) + if packetDir == conn.Direction { + newState = TCPStateEstablished + } } case TCPStateEstablished: if flags&TCPFin != 0 { - if isOutbound { - conn.State = TCPStateFinWait1 + if packetDir == conn.Direction { + newState = TCPStateFinWait1 } else { - conn.State = TCPStateCloseWait + newState = TCPStateCloseWait } - conn.SetEstablished(false) - } else if flags&TCPRst != 0 { - conn.State = TCPStateClosed - conn.SetTombstone() - t.sendEvent(nftypes.TypeEnd, conn, nil) } case TCPStateFinWait1: - switch { - case flags&TCPFin != 0 && flags&TCPAck != 0: - conn.State = TCPStateClosing - case flags&TCPFin != 0: - conn.State = TCPStateFinWait2 - case flags&TCPAck != 0: - conn.State = TCPStateFinWait2 - case flags&TCPRst != 0: - conn.State = TCPStateClosed - conn.SetTombstone() - t.sendEvent(nftypes.TypeEnd, conn, nil) + if packetDir != conn.Direction { + switch { + case flags&TCPFin != 0 && flags&TCPAck != 0: + newState = TCPStateClosing + case flags&TCPFin != 0: + newState = TCPStateClosing + case flags&TCPAck != 0: + newState = TCPStateFinWait2 + } } case TCPStateFinWait2: if flags&TCPFin != 0 { - conn.State = TCPStateTimeWait - - t.logger.Trace("TCP connection %s completed", key) - t.sendEvent(nftypes.TypeEnd, conn, nil) + newState = TCPStateTimeWait } case TCPStateClosing: if flags&TCPAck != 0 { - conn.State = TCPStateTimeWait - // Keep established = false from previous state - - t.logger.Trace("TCP connection %s closed (simultaneous)", key) - t.sendEvent(nftypes.TypeEnd, conn, nil) + newState = TCPStateTimeWait } case TCPStateCloseWait: if flags&TCPFin != 0 { - conn.State = TCPStateLastAck + newState = TCPStateLastAck } case TCPStateLastAck: if flags&TCPAck != 0 { - conn.State = TCPStateClosed - conn.SetTombstone() + newState = TCPStateClosed + } + } - // Send close event for gracefully closed connections + if newState != 0 && conn.CompareAndSwapState(currentState, newState) { + t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir) + + switch newState { + case TCPStateTimeWait: + t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) + t.sendEvent(nftypes.TypeEnd, conn, nil) + + case TCPStateClosed: + conn.SetTombstone() + t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]", + key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) - t.logger.Trace("TCP connection %s closed gracefully", key) } } } @@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { if !isValidFlagCombination(flags) { return false } + if flags&TCPRst != 0 { + if state == TCPStateSynSent { + return flags&TCPAck != 0 + } + return true + } switch state { case TCPStateNew: return flags&TCPSyn != 0 && flags&TCPAck == 0 case TCPStateSynSent: + // TODO: support simultaneous open return flags&TCPSyn != 0 && flags&TCPAck != 0 case TCPStateSynReceived: return flags&TCPAck != 0 case TCPStateEstablished: - if flags&TCPRst != 0 { - return true - } return flags&TCPAck != 0 case TCPStateFinWait1: return flags&TCPFin != 0 || flags&TCPAck != 0 @@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { case TCPStateLastAck: return flags&TCPAck != 0 case TCPStateClosed: - // Accept retransmitted ACKs in closed state - // This is important because the final ACK might be lost - // and the peer will retransmit their FIN-ACK + // Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK return flags&TCPAck != 0 } return false @@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() { } var timeout time.Duration - switch { - case conn.State == TCPStateTimeWait: - timeout = TimeWaitTimeout - case conn.IsEstablished(): + currentState := conn.GetState() + switch currentState { + case TCPStateTimeWait: + timeout = t.waitTimeout + case TCPStateEstablished: timeout = t.timeout default: timeout = TCPHandshakeTimeout } if conn.timeoutExceeded(timeout) { - // Return IPs to pool delete(t.connections, key) - t.logger.Trace("Cleaned up timed-out TCP connection %s", key) + t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]", + key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) // event already handled by state change - if conn.State != TCPStateTimeWait { + if currentState != TCPStateTimeWait { t.sendEvent(nftypes.TypeEnd, conn, nil) } } diff --git a/client/firewall/uspfilter/conntrack/tcp_bench_test.go b/client/firewall/uspfilter/conntrack/tcp_bench_test.go new file mode 100644 index 000000000..9ecb3af9f --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_bench_test.go @@ -0,0 +1,83 @@ +package conntrack + +import ( + "net/netip" + "testing" + "time" +) + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := netip.MustParseAddr("192.168.1.1") + dstIP := netip.MustParseAddr("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 96558583d..d01a8db4f 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) { // Receive RST valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) require.True(t, valid, "RST should be allowed for established connection") - - // Connection is logically dead but we don't enforce blocking subsequent packets - // The connection will be cleaned up by timeout }, }, { @@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) { conn := tracker.connections[key] if tt.wantValid { require.NotNil(t, conn) - require.Equal(t, TCPStateClosed, conn.State) - require.False(t, conn.IsEstablished()) + require.Equal(t, TCPStateClosed, conn.GetState()) } }) } } +func TestTCPRetransmissions(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + // Test SYN retransmission + t.Run("SYN Retransmission", func(t *testing.T) { + // Initial SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + + // Retransmit SYN (should not affect the state machine) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + + // Verify we're still in SYN-SENT state + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + require.Equal(t, TCPStateSynSent, conn.GetState()) + + // Complete the handshake + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) + require.True(t, valid) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + // Verify we're in ESTABLISHED state + require.Equal(t, TCPStateEstablished, conn.GetState()) + }) + + // Test ACK retransmission in established state + t.Run("ACK Retransmission", func(t *testing.T) { + tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + + // Establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Get connection object + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + require.Equal(t, TCPStateEstablished, conn.GetState()) + + // Retransmit ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + // State should remain ESTABLISHED + require.Equal(t, TCPStateEstablished, conn.GetState()) + }) + + // Test FIN retransmission + t.Run("FIN Retransmission", func(t *testing.T) { + tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + + // Establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Get connection object + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Retransmit FIN (should not change state) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateFinWait2, conn.GetState()) + }) +} + +func TestTCPDataTransfer(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Data Transfer", func(t *testing.T) { + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Get connection object + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + + // Send data + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000) + + // Receive ACK for data + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100) + require.True(t, valid) + + // Receive data + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500) + require.True(t, valid) + + // Send ACK for received data + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100) + + // State should remain ESTABLISHED + require.Equal(t, TCPStateEstablished, conn.GetState()) + + assert.Equal(t, uint64(1300), conn.BytesTx.Load()) + assert.Equal(t, uint64(1700), conn.BytesRx.Load()) + assert.Equal(t, uint64(4), conn.PacketsTx.Load()) + assert.Equal(t, uint64(3), conn.PacketsRx.Load()) + }) +} + +func TestTCPHalfClosedConnections(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + // Test half-closed connection: local end closes, remote end continues sending data + t.Run("Local Close, Remote Data", func(t *testing.T) { + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateFinWait2, conn.GetState()) + + // Remote end can still send data + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000) + require.True(t, valid) + + // We can still ACK their data + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + // Receive FIN from remote end + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateTimeWait, conn.GetState()) + + // Send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + // State should remain TIME-WAIT (waiting for possible retransmissions) + require.Equal(t, TCPStateTimeWait, conn.GetState()) + }) + + // Test half-closed connection: remote end closes, local end continues sending data + t.Run("Remote Close, Local Data", func(t *testing.T) { + tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + + // Establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Get connection object + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + + // Receive FIN from remote + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateCloseWait, conn.GetState()) + + // We can still send data + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000) + + // Remote can still ACK our data + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + + // Send our FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + require.Equal(t, TCPStateLastAck, conn.GetState()) + + // Receive final ACK + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateClosed, conn.GetState()) + }) +} + +func TestTCPAbnormalSequences(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + // Test handling of unsolicited RST in various states + t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) { + // Send SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + + // Receive unsolicited RST (without proper ACK) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0) + require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected") + + // Receive RST with proper ACK + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0) + require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted") + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.Equal(t, TCPStateClosed, conn.GetState()) + require.True(t, conn.IsTombstone()) + }) +} + +func TestTCPTimeoutHandling(t *testing.T) { + // Create tracker with a very short timeout for testing + shortTimeout := 100 * time.Millisecond + tracker := NewTCPTracker(shortTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Connection Timeout", func(t *testing.T) { + // Establish a connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Get connection object + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + require.Equal(t, TCPStateEstablished, conn.GetState()) + + // Wait for the connection to timeout + time.Sleep(2 * shortTimeout) + + // Force cleanup + tracker.cleanup() + + // Connection should be removed + _, exists := tracker.connections[key] + require.False(t, exists, "Connection should be removed after timeout") + }) + + t.Run("TIME_WAIT Timeout", func(t *testing.T) { + tracker = NewTCPTracker(shortTimeout, logger, flowLogger) + + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn := tracker.connections[key] + require.NotNil(t, conn) + + // Complete the connection close to enter TIME_WAIT + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + require.Equal(t, TCPStateTimeWait, conn.GetState()) + + // TIME_WAIT should have its own timeout value (usually 2*MSL) + // For the test, we're using a short timeout + time.Sleep(2 * shortTimeout) + + tracker.cleanup() + + // Connection should be removed + _, exists := tracker.connections[key] + require.False(t, exists, "Connection should be removed after TIME_WAIT timeout") + }) +} + +func TestSynFlood(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + basePort := uint16(10000) + dstPort := uint16(80) + + // Create a large number of SYN packets to simulate a SYN flood + for i := uint16(0); i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0) + } + + // Check that we're tracking all connections + require.Equal(t, 1000, len(tracker.connections)) + + // Now simulate SYN timeout + var oldConns int + tracker.mutex.Lock() + for _, conn := range tracker.connections { + if conn.GetState() == TCPStateSynSent { + // Make the connection appear old + conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano()) + oldConns++ + } + } + tracker.mutex.Unlock() + require.Equal(t, 1000, oldConns) + + // Run cleanup + tracker.cleanup() + + // Check that stale connections were cleaned up + require.Equal(t, 0, len(tracker.connections)) +} + +func TestTCPInboundInitiatedConnection(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + clientIP := netip.MustParseAddr("100.64.0.1") + serverIP := netip.MustParseAddr("100.64.0.2") + clientPort := uint16(12345) + serverPort := uint16(80) + + // 1. Client sends SYN (we receive it as inbound) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + + key := ConnKey{ + SrcIP: clientIP, + DstIP: serverIP, + SrcPort: clientPort, + DstPort: serverPort, + } + + tracker.mutex.RLock() + conn := tracker.connections[key] + tracker.mutex.RUnlock() + + require.NotNil(t, conn) + require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN") + + // 2. Server sends SYN-ACK response + tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) + + // 3. Client sends ACK to complete handshake + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") + + // 4. Test data transfer + // Client sends data + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + + // Server sends ACK for data + tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) + + // Server sends data + tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) + + // Client sends ACK for data + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + + // Verify state and counters + require.Equal(t, TCPStateEstablished, conn.GetState()) + assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data + assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data + assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data + assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data +} + // Helper to establish a TCP connection func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) { t.Helper() - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100) - valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100) require.True(t, valid, "SYN-ACK should be allowed") - tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) -} - -func BenchmarkTCPTracker(b *testing.B) { - b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) - defer tracker.Close() - - srcIP := netip.MustParseAddr("192.168.1.1") - dstIP := netip.MustParseAddr("192.168.1.2") - - b.ResetTimer() - for i := 0; i < b.N; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) - } - }) - - b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) - defer tracker.Close() - - srcIP := netip.MustParseAddr("192.168.1.1") - dstIP := netip.MustParseAddr("192.168.1.2") - - // Pre-populate some connections - for i := 0; i < 1000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0) - } - }) - - b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) - defer tracker.Close() - - srcIP := netip.MustParseAddr("192.168.1.1") - dstIP := netip.MustParseAddr("192.168.1.2") - - b.RunParallel(func(pb *testing.PB) { - i := 0 - for pb.Next() { - if i%2 == 0 { - tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0) - } else { - tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0) - } - i++ - } - }) - }) -} - -// Benchmark connection cleanup -func BenchmarkCleanup(b *testing.B) { - b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing - defer tracker.Close() - - // Pre-populate with expired connections - srcIP := netip.MustParseAddr("192.168.1.1") - dstIP := netip.MustParseAddr("192.168.1.2") - for i := 0; i < 10000; i++ { - tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0) - } - - // Wait for connections to expire - time.Sleep(200 * time.Millisecond) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - tracker.cleanup() - } - }) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100) } diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index d72988d27..000eaa1b6 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() { if conn.timeoutExceeded(t.timeout) { delete(t.connections, key) - t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]", + t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]", key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load()) t.sendEvent(nftypes.TypeEnd, conn, nil) }