mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Pass optional dialer to constructor instead of mutating internal fields from tests
This commit is contained in:
@@ -47,19 +47,36 @@ type ConnTrack struct {
|
|||||||
sysctlModified bool
|
sysctlModified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialFunc is a constructor for netlink conntrack connections.
|
||||||
|
type DialFunc func() (listener, error)
|
||||||
|
|
||||||
|
// Option configures a ConnTrack instance.
|
||||||
|
type Option func(*ConnTrack)
|
||||||
|
|
||||||
|
// WithDialer overrides the default netlink dialer, primarily for testing.
|
||||||
|
func WithDialer(dial DialFunc) Option {
|
||||||
|
return func(c *ConnTrack) {
|
||||||
|
c.dial = dial
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func defaultDial() (listener, error) {
|
func defaultDial() (listener, error) {
|
||||||
return nfct.Dial(nil)
|
return nfct.Dial(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack {
|
||||||
return &ConnTrack{
|
ct := &ConnTrack{
|
||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
iface: iface,
|
iface: iface,
|
||||||
instanceID: uuid.New(),
|
instanceID: uuid.New(),
|
||||||
dial: defaultDial,
|
dial: defaultDial,
|
||||||
done: make(chan struct{}, 1),
|
done: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(ct)
|
||||||
|
}
|
||||||
|
return ct
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||||
|
|||||||
@@ -40,18 +40,14 @@ func (m *mockListener) Close() error {
|
|||||||
func TestReconnectAfterError(t *testing.T) {
|
func TestReconnectAfterError(t *testing.T) {
|
||||||
first := newMockListener()
|
first := newMockListener()
|
||||||
second := newMockListener()
|
second := newMockListener()
|
||||||
|
third := newMockListener()
|
||||||
|
listeners := []*mockListener{first, second, third}
|
||||||
callCount := atomic.Int32{}
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
ct := &ConnTrack{
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
dial: func() (listener, error) {
|
n := int(callCount.Add(1)) - 1
|
||||||
n := callCount.Add(1)
|
return listeners[n], nil
|
||||||
if n == 1 {
|
}))
|
||||||
return first, nil
|
|
||||||
}
|
|
||||||
return second, nil
|
|
||||||
},
|
|
||||||
done: make(chan struct{}, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
err := ct.Start(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -72,15 +68,6 @@ func TestReconnectAfterError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify the receiver is still running by injecting and handling a second error.
|
// Verify the receiver is still running by injecting and handling a second error.
|
||||||
third := newMockListener()
|
|
||||||
callCount.Store(2)
|
|
||||||
ct.mux.Lock()
|
|
||||||
ct.dial = func() (listener, error) {
|
|
||||||
callCount.Add(1)
|
|
||||||
return third, nil
|
|
||||||
}
|
|
||||||
ct.mux.Unlock()
|
|
||||||
|
|
||||||
second.errChan <- assert.AnError
|
second.errChan <- assert.AnError
|
||||||
|
|
||||||
require.Eventually(t, func() bool {
|
require.Eventually(t, func() bool {
|
||||||
@@ -93,12 +80,9 @@ func TestReconnectAfterError(t *testing.T) {
|
|||||||
func TestStopDuringReconnectBackoff(t *testing.T) {
|
func TestStopDuringReconnectBackoff(t *testing.T) {
|
||||||
mock := newMockListener()
|
mock := newMockListener()
|
||||||
|
|
||||||
ct := &ConnTrack{
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
dial: func() (listener, error) {
|
return mock, nil
|
||||||
return mock, nil
|
}))
|
||||||
},
|
|
||||||
done: make(chan struct{}, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
err := ct.Start(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -129,19 +113,16 @@ func TestStopRaceWithReconnectDial(t *testing.T) {
|
|||||||
second := newMockListener()
|
second := newMockListener()
|
||||||
callCount := atomic.Int32{}
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
ct := &ConnTrack{
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
dial: func() (listener, error) {
|
n := callCount.Add(1)
|
||||||
n := callCount.Add(1)
|
if n == 1 {
|
||||||
if n == 1 {
|
return first, nil
|
||||||
return first, nil
|
}
|
||||||
}
|
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||||
// Second dial: signal that we're in progress, wait for test to call Stop.
|
close(dialStarted)
|
||||||
close(dialStarted)
|
<-dialProceed
|
||||||
<-dialProceed
|
return second, nil
|
||||||
return second, nil
|
}))
|
||||||
},
|
|
||||||
done: make(chan struct{}, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
err := ct.Start(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -182,18 +163,15 @@ func TestCloseRaceWithReconnectDial(t *testing.T) {
|
|||||||
second := newMockListener()
|
second := newMockListener()
|
||||||
callCount := atomic.Int32{}
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
ct := &ConnTrack{
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
dial: func() (listener, error) {
|
n := callCount.Add(1)
|
||||||
n := callCount.Add(1)
|
if n == 1 {
|
||||||
if n == 1 {
|
return first, nil
|
||||||
return first, nil
|
}
|
||||||
}
|
close(dialStarted)
|
||||||
close(dialStarted)
|
<-dialProceed
|
||||||
<-dialProceed
|
return second, nil
|
||||||
return second, nil
|
}))
|
||||||
},
|
|
||||||
done: make(chan struct{}, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
err := ct.Start(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -228,13 +206,10 @@ func TestStartIsIdempotent(t *testing.T) {
|
|||||||
mock := newMockListener()
|
mock := newMockListener()
|
||||||
callCount := atomic.Int32{}
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
ct := &ConnTrack{
|
ct := New(nil, nil, WithDialer(func() (listener, error) {
|
||||||
dial: func() (listener, error) {
|
callCount.Add(1)
|
||||||
callCount.Add(1)
|
return mock, nil
|
||||||
return mock, nil
|
}))
|
||||||
},
|
|
||||||
done: make(chan struct{}, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := ct.Start(false)
|
err := ct.Start(false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
Reference in New Issue
Block a user