Pass optional dialer to constructor instead of mutating internal fields from tests

This commit is contained in:
Viktor Liu
2026-04-15 19:24:53 +02:00
parent 4462550a51
commit 129c951588
2 changed files with 51 additions and 59 deletions

View File

@@ -47,19 +47,36 @@ type ConnTrack struct {
sysctlModified bool sysctlModified bool
} }
// DialFunc is a constructor for netlink conntrack connections.
type DialFunc func() (listener, error)
// Option configures a ConnTrack instance.
type Option func(*ConnTrack)
// WithDialer overrides the default netlink dialer, primarily for testing.
func WithDialer(dial DialFunc) Option {
return func(c *ConnTrack) {
c.dial = dial
}
}
func defaultDial() (listener, error) { func defaultDial() (listener, error) {
return nfct.Dial(nil) return nfct.Dial(nil)
} }
// New creates a new connection tracker that interfaces with the kernel's conntrack system // New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
return &ConnTrack{ ct := &ConnTrack{
flowLogger: flowLogger, flowLogger: flowLogger,
iface: iface, iface: iface,
instanceID: uuid.New(), instanceID: uuid.New(),
dial: defaultDial, dial: defaultDial,
done: make(chan struct{}, 1), done: make(chan struct{}, 1),
} }
for _, opt := range opts {
opt(ct)
}
return ct
} }
// Start begins tracking connections by listening for conntrack events. This method is idempotent. // Start begins tracking connections by listening for conntrack events. This method is idempotent.

View File

@@ -40,18 +40,14 @@ func (m *mockListener) Close() error {
func TestReconnectAfterError(t *testing.T) { func TestReconnectAfterError(t *testing.T) {
first := newMockListener() first := newMockListener()
second := newMockListener() second := newMockListener()
third := newMockListener()
listeners := []*mockListener{first, second, third}
callCount := atomic.Int32{} callCount := atomic.Int32{}
ct := &ConnTrack{ ct := New(nil, nil, WithDialer(func() (listener, error) {
dial: func() (listener, error) { n := int(callCount.Add(1)) - 1
n := callCount.Add(1) return listeners[n], nil
if n == 1 { }))
return first, nil
}
return second, nil
},
done: make(chan struct{}, 1),
}
err := ct.Start(false) err := ct.Start(false)
require.NoError(t, err) require.NoError(t, err)
@@ -72,15 +68,6 @@ func TestReconnectAfterError(t *testing.T) {
} }
// Verify the receiver is still running by injecting and handling a second error. // Verify the receiver is still running by injecting and handling a second error.
third := newMockListener()
callCount.Store(2)
ct.mux.Lock()
ct.dial = func() (listener, error) {
callCount.Add(1)
return third, nil
}
ct.mux.Unlock()
second.errChan <- assert.AnError second.errChan <- assert.AnError
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
@@ -93,12 +80,9 @@ func TestReconnectAfterError(t *testing.T) {
func TestStopDuringReconnectBackoff(t *testing.T) { func TestStopDuringReconnectBackoff(t *testing.T) {
mock := newMockListener() mock := newMockListener()
ct := &ConnTrack{ ct := New(nil, nil, WithDialer(func() (listener, error) {
dial: func() (listener, error) { return mock, nil
return mock, nil }))
},
done: make(chan struct{}, 1),
}
err := ct.Start(false) err := ct.Start(false)
require.NoError(t, err) require.NoError(t, err)
@@ -129,19 +113,16 @@ func TestStopRaceWithReconnectDial(t *testing.T) {
second := newMockListener() second := newMockListener()
callCount := atomic.Int32{} callCount := atomic.Int32{}
ct := &ConnTrack{ ct := New(nil, nil, WithDialer(func() (listener, error) {
dial: func() (listener, error) { n := callCount.Add(1)
n := callCount.Add(1) if n == 1 {
if n == 1 { return first, nil
return first, nil }
} // Second dial: signal that we're in progress, wait for test to call Stop.
// Second dial: signal that we're in progress, wait for test to call Stop. close(dialStarted)
close(dialStarted) <-dialProceed
<-dialProceed return second, nil
return second, nil }))
},
done: make(chan struct{}, 1),
}
err := ct.Start(false) err := ct.Start(false)
require.NoError(t, err) require.NoError(t, err)
@@ -182,18 +163,15 @@ func TestCloseRaceWithReconnectDial(t *testing.T) {
second := newMockListener() second := newMockListener()
callCount := atomic.Int32{} callCount := atomic.Int32{}
ct := &ConnTrack{ ct := New(nil, nil, WithDialer(func() (listener, error) {
dial: func() (listener, error) { n := callCount.Add(1)
n := callCount.Add(1) if n == 1 {
if n == 1 { return first, nil
return first, nil }
} close(dialStarted)
close(dialStarted) <-dialProceed
<-dialProceed return second, nil
return second, nil }))
},
done: make(chan struct{}, 1),
}
err := ct.Start(false) err := ct.Start(false)
require.NoError(t, err) require.NoError(t, err)
@@ -228,13 +206,10 @@ func TestStartIsIdempotent(t *testing.T) {
mock := newMockListener() mock := newMockListener()
callCount := atomic.Int32{} callCount := atomic.Int32{}
ct := &ConnTrack{ ct := New(nil, nil, WithDialer(func() (listener, error) {
dial: func() (listener, error) { callCount.Add(1)
callCount.Add(1) return mock, nil
return mock, nil }))
},
done: make(chan struct{}, 1),
}
err := ct.Start(false) err := ct.Start(false)
require.NoError(t, err) require.NoError(t, err)