diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 992c10769..335a3abab 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool { return t.tombstone.Load() } +// IsSupersededBy returns true if this connection should be replaced by a new one +// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT +// connections are superseded by a pure SYN (a new connection attempt for the same +// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191). +func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool { + if t.tombstone.Load() { + return true + } + return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait +} + // SetTombstone safely marks the connection for deletion func (t *TCPConnTrack) SetTombstone() { t.tombstone.Store(true) @@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui conn, exists := t.connections[key] t.mutex.RUnlock() - if exists && !conn.IsTombstone() { + if exists && !conn.IsSupersededBy(flags) { t.updateState(key, conn, flags, direction, size) return key, uint16(conn.DNATOrigPort.Load()), true } @@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui conn, exists := t.connections[key] t.mutex.RUnlock() - if !exists || conn.IsTombstone() { + if !exists || conn.IsSupersededBy(flags) { return false } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 52bc3858f..f46c5c1ab 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -636,6 +636,110 @@ func TestTCPPortReuseTombstone(t *testing.T) { }) } +func TestTCPPortReuseTimeWait(t *testing.T) { + srcIP := netip.MustParseAddr("100.64.0.1") + dstIP := netip.MustParseAddr("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + // Establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Active close: client (outbound initiator) sends FIN first + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + conn := tracker.connections[key] + require.Equal(t, TCPStateFinWait1, conn.GetState()) + + // Server ACKs the FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateFinWait2, conn.GetState()) + + // Server sends its own FIN + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + require.True(t, valid) + require.Equal(t, TCPStateTimeWait, conn.GetState()) + + // Client sends final ACK (TIME-WAIT stays, not tombstoned) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned") + + // New outbound SYN on the same port (port reuse during TIME-WAIT) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100) + + // Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection + newConn := tracker.connections[key] + require.NotNil(t, newConn, "new connection should exist") + require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned") + require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT") + + // SYN-ACK for new connection should be valid + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100) + require.True(t, valid, "SYN-ACK for new connection should be accepted") + require.Equal(t, TCPStateEstablished, newConn.GetState()) + }) + + t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + // Establish outbound connection and close via active close → TIME-WAIT + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + conn := tracker.connections[key] + require.Equal(t, TCPStateTimeWait, conn.GetState()) + + // Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false + // so the filter falls through to ACL check + TrackInbound (which creates + // a new connection via track() → updateIfExists skips TIME-WAIT for SYN) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0) + require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation") + + // Simulate what the filter does next: TrackInbound via the normal path + tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0) + + // The new inbound connection uses the inverted key (dst→src becomes src→dst in track) + invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort} + newConn := tracker.connections[invertedKey] + require.NotNil(t, newConn, "new inbound connection should be tracked") + require.Equal(t, TCPStateSynReceived, newConn.GetState()) + require.False(t, newConn.IsTombstone()) + }) + + t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger) + defer tracker.Close() + + key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort} + + // Establish and active close → TIME-WAIT + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0) + + conn := tracker.connections[key] + require.Equal(t, TCPStateTimeWait, conn.GetState()) + + // Late ACK retransmits during TIME-WAIT should still be accepted + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) + require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted") + }) +} + func TestTCPTimeoutHandling(t *testing.T) { // Create tracker with a very short timeout for testing shortTimeout := 100 * time.Millisecond