Pass optional dialer to constructor instead of mutating internal fields from tests

This commit is contained in:
Viktor Liu
2026-04-15 19:24:53 +02:00
parent 4462550a51
commit 129c951588
2 changed files with 51 additions and 59 deletions

View File

@@ -47,19 +47,36 @@ type ConnTrack struct {
sysctlModified bool
}
// DialFunc is a constructor for netlink conntrack connections.
type DialFunc func() (listener, error)
// Option configures a ConnTrack instance.
type Option func(*ConnTrack)
// WithDialer overrides the default netlink dialer, primarily for testing.
func WithDialer(dial DialFunc) Option {
return func(c *ConnTrack) {
c.dial = dial
}
}
func defaultDial() (listener, error) {
return nfct.Dial(nil)
}
// New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
return &ConnTrack{
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
ct := &ConnTrack{
flowLogger: flowLogger,
iface: iface,
instanceID: uuid.New(),
dial: defaultDial,
done: make(chan struct{}, 1),
}
for _, opt := range opts {
opt(ct)
}
return ct
}
// Start begins tracking connections by listening for conntrack events. This method is idempotent.

View File

@@ -40,18 +40,14 @@ func (m *mockListener) Close() error {
func TestReconnectAfterError(t *testing.T) {
first := newMockListener()
second := newMockListener()
third := newMockListener()
listeners := []*mockListener{first, second, third}
callCount := atomic.Int32{}
ct := &ConnTrack{
dial: func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
return second, nil
},
done: make(chan struct{}, 1),
}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := int(callCount.Add(1)) - 1
return listeners[n], nil
}))
err := ct.Start(false)
require.NoError(t, err)
@@ -72,15 +68,6 @@ func TestReconnectAfterError(t *testing.T) {
}
// Verify the receiver is still running by injecting and handling a second error.
third := newMockListener()
callCount.Store(2)
ct.mux.Lock()
ct.dial = func() (listener, error) {
callCount.Add(1)
return third, nil
}
ct.mux.Unlock()
second.errChan <- assert.AnError
require.Eventually(t, func() bool {
@@ -93,12 +80,9 @@ func TestReconnectAfterError(t *testing.T) {
func TestStopDuringReconnectBackoff(t *testing.T) {
mock := newMockListener()
ct := &ConnTrack{
dial: func() (listener, error) {
return mock, nil
},
done: make(chan struct{}, 1),
}
ct := New(nil, nil, WithDialer(func() (listener, error) {
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)
@@ -129,19 +113,16 @@ func TestStopRaceWithReconnectDial(t *testing.T) {
second := newMockListener()
callCount := atomic.Int32{}
ct := &ConnTrack{
dial: func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
// Second dial: signal that we're in progress, wait for test to call Stop.
close(dialStarted)
<-dialProceed
return second, nil
},
done: make(chan struct{}, 1),
}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
// Second dial: signal that we're in progress, wait for test to call Stop.
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
@@ -182,18 +163,15 @@ func TestCloseRaceWithReconnectDial(t *testing.T) {
second := newMockListener()
callCount := atomic.Int32{}
ct := &ConnTrack{
dial: func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
close(dialStarted)
<-dialProceed
return second, nil
},
done: make(chan struct{}, 1),
}
ct := New(nil, nil, WithDialer(func() (listener, error) {
n := callCount.Add(1)
if n == 1 {
return first, nil
}
close(dialStarted)
<-dialProceed
return second, nil
}))
err := ct.Start(false)
require.NoError(t, err)
@@ -228,13 +206,10 @@ func TestStartIsIdempotent(t *testing.T) {
mock := newMockListener()
callCount := atomic.Int32{}
ct := &ConnTrack{
dial: func() (listener, error) {
callCount.Add(1)
return mock, nil
},
done: make(chan struct{}, 1),
}
ct := New(nil, nil, WithDialer(func() (listener, error) {
callCount.Add(1)
return mock, nil
}))
err := ct.Start(false)
require.NoError(t, err)