diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index 53619ab8c..2420b1fdf 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -47,19 +47,36 @@ type ConnTrack struct { sysctlModified bool } +// DialFunc is a constructor for netlink conntrack connections. +type DialFunc func() (listener, error) + +// Option configures a ConnTrack instance. +type Option func(*ConnTrack) + +// WithDialer overrides the default netlink dialer, primarily for testing. +func WithDialer(dial DialFunc) Option { + return func(c *ConnTrack) { + c.dial = dial + } +} + func defaultDial() (listener, error) { return nfct.Dial(nil) } // New creates a new connection tracker that interfaces with the kernel's conntrack system -func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { - return &ConnTrack{ +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack { + ct := &ConnTrack{ flowLogger: flowLogger, iface: iface, instanceID: uuid.New(), dial: defaultDial, done: make(chan struct{}, 1), } + for _, opt := range opts { + opt(ct) + } + return ct } // Start begins tracking connections by listening for conntrack events. This method is idempotent. diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go index 1d8ec268a..35ceec90d 100644 --- a/client/internal/netflow/conntrack/conntrack_test.go +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -40,18 +40,14 @@ func (m *mockListener) Close() error { func TestReconnectAfterError(t *testing.T) { first := newMockListener() second := newMockListener() + third := newMockListener() + listeners := []*mockListener{first, second, third} callCount := atomic.Int32{} - ct := &ConnTrack{ - dial: func() (listener, error) { - n := callCount.Add(1) - if n == 1 { - return first, nil - } - return second, nil - }, - done: make(chan struct{}, 1), - } + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := int(callCount.Add(1)) - 1 + return listeners[n], nil + })) err := ct.Start(false) require.NoError(t, err) @@ -72,15 +68,6 @@ func TestReconnectAfterError(t *testing.T) { } // Verify the receiver is still running by injecting and handling a second error. - third := newMockListener() - callCount.Store(2) - ct.mux.Lock() - ct.dial = func() (listener, error) { - callCount.Add(1) - return third, nil - } - ct.mux.Unlock() - second.errChan <- assert.AnError require.Eventually(t, func() bool { @@ -93,12 +80,9 @@ func TestReconnectAfterError(t *testing.T) { func TestStopDuringReconnectBackoff(t *testing.T) { mock := newMockListener() - ct := &ConnTrack{ - dial: func() (listener, error) { - return mock, nil - }, - done: make(chan struct{}, 1), - } + ct := New(nil, nil, WithDialer(func() (listener, error) { + return mock, nil + })) err := ct.Start(false) require.NoError(t, err) @@ -129,19 +113,16 @@ func TestStopRaceWithReconnectDial(t *testing.T) { second := newMockListener() callCount := atomic.Int32{} - ct := &ConnTrack{ - dial: func() (listener, error) { - n := callCount.Add(1) - if n == 1 { - return first, nil - } - // Second dial: signal that we're in progress, wait for test to call Stop. - close(dialStarted) - <-dialProceed - return second, nil - }, - done: make(chan struct{}, 1), - } + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + // Second dial: signal that we're in progress, wait for test to call Stop. + close(dialStarted) + <-dialProceed + return second, nil + })) err := ct.Start(false) require.NoError(t, err) @@ -182,18 +163,15 @@ func TestCloseRaceWithReconnectDial(t *testing.T) { second := newMockListener() callCount := atomic.Int32{} - ct := &ConnTrack{ - dial: func() (listener, error) { - n := callCount.Add(1) - if n == 1 { - return first, nil - } - close(dialStarted) - <-dialProceed - return second, nil - }, - done: make(chan struct{}, 1), - } + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + close(dialStarted) + <-dialProceed + return second, nil + })) err := ct.Start(false) require.NoError(t, err) @@ -228,13 +206,10 @@ func TestStartIsIdempotent(t *testing.T) { mock := newMockListener() callCount := atomic.Int32{} - ct := &ConnTrack{ - dial: func() (listener, error) { - callCount.Add(1) - return mock, nil - }, - done: make(chan struct{}, 1), - } + ct := New(nil, nil, WithDialer(func() (listener, error) { + callCount.Add(1) + return mock, nil + })) err := ct.Start(false) require.NoError(t, err)