mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
Pass optional dialer to constructor instead of mutating internal fields from tests
This commit is contained in:
@@ -47,19 +47,36 @@ type ConnTrack struct {
|
||||
sysctlModified bool
|
||||
}
|
||||
|
||||
// DialFunc is a constructor for netlink conntrack connections.
|
||||
type DialFunc func() (listener, error)
|
||||
|
||||
// Option configures a ConnTrack instance.
|
||||
type Option func(*ConnTrack)
|
||||
|
||||
// WithDialer overrides the default netlink dialer, primarily for testing.
|
||||
func WithDialer(dial DialFunc) Option {
|
||||
return func(c *ConnTrack) {
|
||||
c.dial = dial
|
||||
}
|
||||
}
|
||||
|
||||
func defaultDial() (listener, error) {
|
||||
return nfct.Dial(nil)
|
||||
}
|
||||
|
||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||
return &ConnTrack{
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
|
||||
ct := &ConnTrack{
|
||||
flowLogger: flowLogger,
|
||||
iface: iface,
|
||||
instanceID: uuid.New(),
|
||||
dial: defaultDial,
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(ct)
|
||||
}
|
||||
return ct
|
||||
}
|
||||
|
||||
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||
|
||||
@@ -40,18 +40,14 @@ func (m *mockListener) Close() error {
|
||||
func TestReconnectAfterError(t *testing.T) {
|
||||
first := newMockListener()
|
||||
second := newMockListener()
|
||||
third := newMockListener()
|
||||
listeners := []*mockListener{first, second, third}
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
return second, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
n := int(callCount.Add(1)) - 1
|
||||
return listeners[n], nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
@@ -72,15 +68,6 @@ func TestReconnectAfterError(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify the receiver is still running by injecting and handling a second error.
|
||||
third := newMockListener()
|
||||
callCount.Store(2)
|
||||
ct.mux.Lock()
|
||||
ct.dial = func() (listener, error) {
|
||||
callCount.Add(1)
|
||||
return third, nil
|
||||
}
|
||||
ct.mux.Unlock()
|
||||
|
||||
second.errChan <- assert.AnError
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
@@ -93,12 +80,9 @@ func TestReconnectAfterError(t *testing.T) {
|
||||
func TestStopDuringReconnectBackoff(t *testing.T) {
|
||||
mock := newMockListener()
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
return mock, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
return mock, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
@@ -129,19 +113,16 @@ func TestStopRaceWithReconnectDial(t *testing.T) {
|
||||
second := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
@@ -182,18 +163,15 @@ func TestCloseRaceWithReconnectDial(t *testing.T) {
|
||||
second := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
@@ -228,13 +206,10 @@ func TestStartIsIdempotent(t *testing.T) {
|
||||
mock := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
callCount.Add(1)
|
||||
return mock, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||
callCount.Add(1)
|
||||
return mock, nil
|
||||
}))
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user