mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Improve TCP conn tracking (#3572)
This commit is contained in:
@@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
@@ -23,11 +23,11 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPAck uint8 = 0x10
|
||||
TCPFin uint8 = 0x01
|
||||
TCPSyn uint8 = 0x02
|
||||
TCPRst uint8 = 0x04
|
||||
TCPPush uint8 = 0x08
|
||||
TCPAck uint8 = 0x10
|
||||
TCPUrg uint8 = 0x20
|
||||
)
|
||||
|
||||
@@ -41,7 +41,7 @@ const (
|
||||
)
|
||||
|
||||
// TCPState represents the state of a TCP connection
|
||||
type TCPState int
|
||||
type TCPState int32
|
||||
|
||||
func (s TCPState) String() string {
|
||||
switch s {
|
||||
@@ -89,22 +89,25 @@ const (
|
||||
// TCPConnTrack represents a TCP connection state
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
State TCPState
|
||||
established atomic.Bool
|
||||
tombstone atomic.Bool
|
||||
sync.RWMutex
|
||||
SourcePort uint16
|
||||
DestPort uint16
|
||||
state atomic.Int32
|
||||
tombstone atomic.Bool
|
||||
}
|
||||
|
||||
// IsEstablished safely checks if connection is established
|
||||
func (t *TCPConnTrack) IsEstablished() bool {
|
||||
return t.established.Load()
|
||||
// GetState safely retrieves the current state
|
||||
func (t *TCPConnTrack) GetState() TCPState {
|
||||
return TCPState(t.state.Load())
|
||||
}
|
||||
|
||||
// SetEstablished safely sets the established state
|
||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||
t.established.Store(state)
|
||||
// SetState safely updates the current state
|
||||
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||
t.state.Store(int32(state))
|
||||
}
|
||||
|
||||
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||
}
|
||||
|
||||
// IsTombstone safely checks if the connection is marked for deletion
|
||||
@@ -125,13 +128,17 @@ type TCPTracker struct {
|
||||
cleanupTicker *time.Ticker
|
||||
tickerCancel context.CancelFunc
|
||||
timeout time.Duration
|
||||
waitTimeout time.Duration
|
||||
flowLogger nftypes.FlowLogger
|
||||
}
|
||||
|
||||
// NewTCPTracker creates a new TCP connection tracker
|
||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||
waitTimeout := TimeWaitTimeout
|
||||
if timeout == 0 {
|
||||
timeout = DefaultTCPTimeout
|
||||
} else {
|
||||
waitTimeout = timeout / 45
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||
tickerCancel: cancel,
|
||||
timeout: timeout,
|
||||
waitTimeout: waitTimeout,
|
||||
flowLogger: flowLogger,
|
||||
}
|
||||
|
||||
@@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
||||
return tracker
|
||||
}
|
||||
|
||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
@@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||
conn.Unlock()
|
||||
|
||||
conn.UpdateCounters(direction, size)
|
||||
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
return key, true
|
||||
}
|
||||
|
||||
@@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
}
|
||||
|
||||
// TrackOutbound records an outbound TCP connection
|
||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||
// if (inverted direction) conn is not tracked, track this direction
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||
@@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u
|
||||
}
|
||||
|
||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||
}
|
||||
|
||||
// track is the common implementation for tracking both inbound and outbound connections
|
||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||
if exists {
|
||||
if exists || flags&TCPSyn == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
DestPort: dstPort,
|
||||
}
|
||||
|
||||
conn.established.Store(false)
|
||||
conn.tombstone.Store(false)
|
||||
conn.state.Store(int32(TCPStateNew))
|
||||
|
||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||
conn.UpdateCounters(direction, size)
|
||||
t.updateState(key, conn, flags, direction, size)
|
||||
|
||||
t.mutex.Lock()
|
||||
t.connections[key] = conn
|
||||
@@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
||||
}
|
||||
|
||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||
key := ConnKey{
|
||||
SrcIP: dstIP,
|
||||
DstIP: srcIP,
|
||||
@@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
if !exists || conn.IsTombstone() {
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle RST flag specially - it always causes transition to closed
|
||||
if flags&TCPRst != 0 {
|
||||
return t.handleRst(key, conn, size)
|
||||
currentState := conn.GetState()
|
||||
if !t.isValidStateForFlags(currentState, flags) {
|
||||
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||
// allow all flags for established for now
|
||||
if currentState == TCPStateEstablished {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
t.updateState(key, conn, flags, false)
|
||||
isEstablished := conn.IsEstablished()
|
||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
return isEstablished || isValidState
|
||||
}
|
||||
|
||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool {
|
||||
if conn.IsTombstone() {
|
||||
return true
|
||||
}
|
||||
|
||||
conn.Lock()
|
||||
conn.SetTombstone()
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
conn.UpdateCounters(nftypes.Ingress, size)
|
||||
|
||||
t.logger.Trace("TCP connection reset: %s", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||
return true
|
||||
}
|
||||
|
||||
// updateState updates the TCP connection state based on flags
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||
conn.UpdateLastSeen()
|
||||
conn.UpdateCounters(packetDir, size)
|
||||
|
||||
state := conn.State
|
||||
defer func() {
|
||||
if state != conn.State {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||
currentState := conn.GetState()
|
||||
|
||||
if flags&TCPRst != 0 {
|
||||
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
switch state {
|
||||
var newState TCPState
|
||||
switch currentState {
|
||||
case TCPStateNew:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
conn.State = TCPStateSynSent
|
||||
if conn.Direction == nftypes.Egress {
|
||||
newState = TCPStateSynSent
|
||||
} else {
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynSent:
|
||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir != conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
} else {
|
||||
// Simultaneous open
|
||||
conn.State = TCPStateSynReceived
|
||||
newState = TCPStateSynReceived
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateSynReceived:
|
||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||
conn.State = TCPStateEstablished
|
||||
conn.SetEstablished(true)
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateEstablished
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPFin != 0 {
|
||||
if isOutbound {
|
||||
conn.State = TCPStateFinWait1
|
||||
if packetDir == conn.Direction {
|
||||
newState = TCPStateFinWait1
|
||||
} else {
|
||||
conn.State = TCPStateCloseWait
|
||||
newState = TCPStateCloseWait
|
||||
}
|
||||
conn.SetEstablished(false)
|
||||
} else if flags&TCPRst != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
case TCPStateFinWait1:
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
conn.State = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPAck != 0:
|
||||
conn.State = TCPStateFinWait2
|
||||
case flags&TCPRst != 0:
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
if packetDir != conn.Direction {
|
||||
switch {
|
||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPFin != 0:
|
||||
newState = TCPStateClosing
|
||||
case flags&TCPAck != 0:
|
||||
newState = TCPStateFinWait2
|
||||
}
|
||||
}
|
||||
|
||||
case TCPStateFinWait2:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
|
||||
t.logger.Trace("TCP connection %s completed", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateClosing:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateTimeWait
|
||||
// Keep established = false from previous state
|
||||
|
||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
newState = TCPStateTimeWait
|
||||
}
|
||||
|
||||
case TCPStateCloseWait:
|
||||
if flags&TCPFin != 0 {
|
||||
conn.State = TCPStateLastAck
|
||||
newState = TCPStateLastAck
|
||||
}
|
||||
|
||||
case TCPStateLastAck:
|
||||
if flags&TCPAck != 0 {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetTombstone()
|
||||
newState = TCPStateClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Send close event for gracefully closed connections
|
||||
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||
|
||||
switch newState {
|
||||
case TCPStateTimeWait:
|
||||
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
|
||||
case TCPStateClosed:
|
||||
conn.SetTombstone()
|
||||
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
if !isValidFlagCombination(flags) {
|
||||
return false
|
||||
}
|
||||
if flags&TCPRst != 0 {
|
||||
if state == TCPStateSynSent {
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
switch state {
|
||||
case TCPStateNew:
|
||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||
case TCPStateSynSent:
|
||||
// TODO: support simultaneous open
|
||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||
case TCPStateSynReceived:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateEstablished:
|
||||
if flags&TCPRst != 0 {
|
||||
return true
|
||||
}
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateFinWait1:
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
@@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
@@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() {
|
||||
}
|
||||
|
||||
var timeout time.Duration
|
||||
switch {
|
||||
case conn.State == TCPStateTimeWait:
|
||||
timeout = TimeWaitTimeout
|
||||
case conn.IsEstablished():
|
||||
currentState := conn.GetState()
|
||||
switch currentState {
|
||||
case TCPStateTimeWait:
|
||||
timeout = t.waitTimeout
|
||||
case TCPStateEstablished:
|
||||
timeout = t.timeout
|
||||
default:
|
||||
timeout = TCPHandshakeTimeout
|
||||
}
|
||||
|
||||
if conn.timeoutExceeded(timeout) {
|
||||
// Return IPs to pool
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
|
||||
// event already handled by state change
|
||||
if conn.State != TCPStateTimeWait {
|
||||
if currentState != TCPStateTimeWait {
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
}
|
||||
|
||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
// Receive RST
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) {
|
||||
conn := tracker.connections[key]
|
||||
if tt.wantValid {
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateClosed, conn.State)
|
||||
require.False(t, conn.IsEstablished())
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTCPRetransmissions(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test SYN retransmission
|
||||
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||
// Initial SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Retransmit SYN (should not affect the state machine)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Verify we're still in SYN-SENT state
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||
|
||||
// Complete the handshake
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Verify we're in ESTABLISHED state
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test ACK retransmission in established state
|
||||
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Retransmit ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
})
|
||||
|
||||
// Test FIN retransmission
|
||||
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Retransmit FIN (should not change state)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPDataTransfer(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Data Transfer", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Receive ACK for data
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||
require.True(t, valid)
|
||||
|
||||
// Receive data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send ACK for received data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
|
||||
// State should remain ESTABLISHED
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test half-closed connection: local end closes, remote end continues sending data
|
||||
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Send FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||
|
||||
// Receive ACK for FIN
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||
|
||||
// Remote end can still send data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||
require.True(t, valid)
|
||||
|
||||
// We can still ACK their data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// Receive FIN from remote end
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// Send final ACK
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
})
|
||||
|
||||
// Test half-closed connection: remote end closes, local end continues sending data
|
||||
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
|
||||
// Establish connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Receive FIN from remote
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||
|
||||
// We can still send data
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||
|
||||
// Remote can still ACK our data
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
|
||||
// Send our FIN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||
|
||||
// Receive final ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
require.True(t, valid)
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPAbnormalSequences(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Test handling of unsolicited RST in various states
|
||||
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||
// Send SYN
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
|
||||
// Receive unsolicited RST (without proper ACK)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||
|
||||
// Receive RST with proper ACK
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||
require.True(t, conn.IsTombstone())
|
||||
})
|
||||
}
|
||||
|
||||
func TestTCPTimeoutHandling(t *testing.T) {
|
||||
// Create tracker with a very short timeout for testing
|
||||
shortTimeout := 100 * time.Millisecond
|
||||
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
srcPort := uint16(12345)
|
||||
dstPort := uint16(80)
|
||||
|
||||
t.Run("Connection Timeout", func(t *testing.T) {
|
||||
// Establish a connection
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
// Get connection object
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
|
||||
// Wait for the connection to timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
// Force cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after timeout")
|
||||
})
|
||||
|
||||
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||
|
||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
SrcPort: srcPort,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
conn := tracker.connections[key]
|
||||
require.NotNil(t, conn)
|
||||
|
||||
// Complete the connection close to enter TIME_WAIT
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
|
||||
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||
|
||||
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||
// For the test, we're using a short timeout
|
||||
time.Sleep(2 * shortTimeout)
|
||||
|
||||
tracker.cleanup()
|
||||
|
||||
// Connection should be removed
|
||||
_, exists := tracker.connections[key]
|
||||
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSynFlood(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||
basePort := uint16(10000)
|
||||
dstPort := uint16(80)
|
||||
|
||||
// Create a large number of SYN packets to simulate a SYN flood
|
||||
for i := uint16(0); i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Check that we're tracking all connections
|
||||
require.Equal(t, 1000, len(tracker.connections))
|
||||
|
||||
// Now simulate SYN timeout
|
||||
var oldConns int
|
||||
tracker.mutex.Lock()
|
||||
for _, conn := range tracker.connections {
|
||||
if conn.GetState() == TCPStateSynSent {
|
||||
// Make the connection appear old
|
||||
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||
oldConns++
|
||||
}
|
||||
}
|
||||
tracker.mutex.Unlock()
|
||||
require.Equal(t, 1000, oldConns)
|
||||
|
||||
// Run cleanup
|
||||
tracker.cleanup()
|
||||
|
||||
// Check that stale connections were cleaned up
|
||||
require.Equal(t, 0, len(tracker.connections))
|
||||
}
|
||||
|
||||
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||
clientPort := uint16(12345)
|
||||
serverPort := uint16(80)
|
||||
|
||||
// 1. Client sends SYN (we receive it as inbound)
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||
|
||||
key := ConnKey{
|
||||
SrcIP: clientIP,
|
||||
DstIP: serverIP,
|
||||
SrcPort: clientPort,
|
||||
DstPort: serverPort,
|
||||
}
|
||||
|
||||
tracker.mutex.RLock()
|
||||
conn := tracker.connections[key]
|
||||
tracker.mutex.RUnlock()
|
||||
|
||||
require.NotNil(t, conn)
|
||||
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||
|
||||
// 2. Server sends SYN-ACK response
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||
|
||||
// 3. Client sends ACK to complete handshake
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||
|
||||
// 4. Test data transfer
|
||||
// Client sends data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||
|
||||
// Server sends ACK for data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||
|
||||
// Server sends data
|
||||
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||
|
||||
// Client sends ACK for data
|
||||
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||
|
||||
// Verify state and counters
|
||||
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||
}
|
||||
|
||||
// Helper to establish a TCP connection
|
||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||
t.Helper()
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||
require.True(t, valid, "SYN-ACK should be allowed")
|
||||
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||
}
|
||||
|
||||
func BenchmarkTCPTracker(b *testing.B) {
|
||||
b.Run("TrackOutbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("IsValidInbound", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
// Pre-populate some connections
|
||||
for i := 0; i < 1000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||
defer tracker.Close()
|
||||
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
if i%2 == 0 {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||
} else {
|
||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Benchmark connection cleanup
|
||||
func BenchmarkCleanup(b *testing.B) {
|
||||
b.Run("TCPCleanup", func(b *testing.B) {
|
||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||
defer tracker.Close()
|
||||
|
||||
// Pre-populate with expired connections
|
||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||
for i := 0; i < 10000; i++ {
|
||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||
}
|
||||
|
||||
// Wait for connections to expire
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
tracker.cleanup()
|
||||
}
|
||||
})
|
||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||
}
|
||||
|
||||
@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
|
||||
if conn.timeoutExceeded(t.timeout) {
|
||||
delete(t.connections, key)
|
||||
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user