diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index 259f5e37e..75e2dfcb6 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -26,27 +26,38 @@ const ( reconnectRandomization = 0.5 ) +// listener abstracts a netlink conntrack connection for testability. +type listener interface { + Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error) + Close() error +} + // ConnTrack manages kernel-based conntrack events type ConnTrack struct { flowLogger nftypes.FlowLogger iface nftypes.IFaceMapper - conn *nfct.Conn + conn listener mux sync.Mutex + dial func() (listener, error) instanceID uuid.UUID started bool done chan struct{} sysctlModified bool } +func defaultDial() (listener, error) { + return nfct.Dial(nil) +} + // New creates a new connection tracker that interfaces with the kernel's conntrack system func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { return &ConnTrack{ flowLogger: flowLogger, iface: iface, instanceID: uuid.New(), - started: false, + dial: defaultDial, done: make(chan struct{}, 1), } } @@ -66,7 +77,7 @@ func (c *ConnTrack) Start(enableCounters bool) error { c.EnableAccounting() } - conn, err := nfct.Dial(nil) + conn, err := c.dial() if err != nil { return fmt.Errorf("dial conntrack: %w", err) } @@ -160,7 +171,7 @@ func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) { case <-time.After(delay): } - conn, err := nfct.Dial(nil) + conn, err := c.dial() if err != nil { log.Warnf("reconnect conntrack dial: %v", err) continue diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go new file mode 100644 index 000000000..aa4a1b2b7 --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -0,0 +1,198 @@ +//go:build linux && !android + +package conntrack + +import ( + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nfct "github.com/ti-mo/conntrack" + "github.com/ti-mo/netfilter" +) + +type mockListener struct { + mu sync.Mutex + errChan chan error + closed atomic.Bool + closedCh chan struct{} +} + +func newMockListener() *mockListener { + return &mockListener{ + errChan: make(chan error, 1), + closedCh: make(chan struct{}), + } +} + +func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) { + return m.errChan, nil +} + +func (m *mockListener) Close() error { + if m.closed.CompareAndSwap(false, true) { + close(m.closedCh) + } + return nil +} + +func TestReconnectAfterError(t *testing.T) { + first := newMockListener() + second := newMockListener() + callCount := atomic.Int32{} + + ct := &ConnTrack{ + dial: func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + return second, nil + }, + done: make(chan struct{}, 1), + } + + err := ct.Start(false) + require.NoError(t, err) + + // Inject an error on the first listener. + first.errChan <- assert.AnError + + // Wait for reconnect to complete. + require.Eventually(t, func() bool { + return callCount.Load() >= 2 + }, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection") + + // The first connection must have been closed. + select { + case <-first.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("first connection was not closed") + } + + // Verify the receiver is still running by injecting and handling a second error. + third := newMockListener() + callCount.Store(2) + ct.mux.Lock() + ct.dial = func() (listener, error) { + callCount.Add(1) + return third, nil + } + ct.mux.Unlock() + + second.errChan <- assert.AnError + + require.Eventually(t, func() bool { + return callCount.Load() >= 3 + }, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed") + + ct.Stop() +} + +func TestStopDuringReconnectBackoff(t *testing.T) { + mock := newMockListener() + + ct := &ConnTrack{ + dial: func() (listener, error) { + return mock, nil + }, + done: make(chan struct{}, 1), + } + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger an error so the receiver enters reconnect. + mock.errChan <- assert.AnError + + // Give the goroutine time to enter the reconnect backoff wait. + time.Sleep(500 * time.Millisecond) + + // Stop while reconnecting. + ct.Stop() + + ct.mux.Lock() + assert.False(t, ct.started, "started should be false after Stop") + assert.Nil(t, ct.conn, "conn should be nil after Stop") + ct.mux.Unlock() +} + +func TestStopRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := &ConnTrack{ + dial: func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + // Second dial: signal that we're in progress, wait for test to call Stop. + close(dialStarted) + <-dialProceed + return second, nil + }, + done: make(chan struct{}, 1), + } + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger error to enter reconnect. + first.errChan <- assert.AnError + + // Wait for reconnect's second dial to begin. + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Stop while dial is in progress (conn is nil at this point). + ct.Stop() + + // Let the dial complete. reconnect should detect started==false and close the new conn. + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Stop") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestStartIsIdempotent(t *testing.T) { + mock := newMockListener() + callCount := atomic.Int32{} + + ct := &ConnTrack{ + dial: func() (listener, error) { + callCount.Add(1) + return mock, nil + }, + done: make(chan struct{}, 1), + } + + err := ct.Start(false) + require.NoError(t, err) + + // Second Start should be a no-op. + err = ct.Start(false) + require.NoError(t, err) + + assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once") + + ct.Stop() +}