mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Add tests for conntrack reconnect behavior
This commit is contained in:
@@ -26,27 +26,38 @@ const (
|
|||||||
reconnectRandomization = 0.5
|
reconnectRandomization = 0.5
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// listener abstracts a netlink conntrack connection for testability.
|
||||||
|
type listener interface {
|
||||||
|
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
// ConnTrack manages kernel-based conntrack events
|
// ConnTrack manages kernel-based conntrack events
|
||||||
type ConnTrack struct {
|
type ConnTrack struct {
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
iface nftypes.IFaceMapper
|
iface nftypes.IFaceMapper
|
||||||
|
|
||||||
conn *nfct.Conn
|
conn listener
|
||||||
mux sync.Mutex
|
mux sync.Mutex
|
||||||
|
|
||||||
|
dial func() (listener, error)
|
||||||
instanceID uuid.UUID
|
instanceID uuid.UUID
|
||||||
started bool
|
started bool
|
||||||
done chan struct{}
|
done chan struct{}
|
||||||
sysctlModified bool
|
sysctlModified bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func defaultDial() (listener, error) {
|
||||||
|
return nfct.Dial(nil)
|
||||||
|
}
|
||||||
|
|
||||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||||
return &ConnTrack{
|
return &ConnTrack{
|
||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
iface: iface,
|
iface: iface,
|
||||||
instanceID: uuid.New(),
|
instanceID: uuid.New(),
|
||||||
started: false,
|
dial: defaultDial,
|
||||||
done: make(chan struct{}, 1),
|
done: make(chan struct{}, 1),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -66,7 +77,7 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
|||||||
c.EnableAccounting()
|
c.EnableAccounting()
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := nfct.Dial(nil)
|
conn, err := c.dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("dial conntrack: %w", err)
|
return fmt.Errorf("dial conntrack: %w", err)
|
||||||
}
|
}
|
||||||
@@ -160,7 +171,7 @@ func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
|
|||||||
case <-time.After(delay):
|
case <-time.After(delay):
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := nfct.Dial(nil)
|
conn, err := c.dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("reconnect conntrack dial: %v", err)
|
log.Warnf("reconnect conntrack dial: %v", err)
|
||||||
continue
|
continue
|
||||||
|
|||||||
198
client/internal/netflow/conntrack/conntrack_test.go
Normal file
198
client/internal/netflow/conntrack/conntrack_test.go
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
nfct "github.com/ti-mo/conntrack"
|
||||||
|
"github.com/ti-mo/netfilter"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockListener struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
errChan chan error
|
||||||
|
closed atomic.Bool
|
||||||
|
closedCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockListener() *mockListener {
|
||||||
|
return &mockListener{
|
||||||
|
errChan: make(chan error, 1),
|
||||||
|
closedCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
|
||||||
|
return m.errChan, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockListener) Close() error {
|
||||||
|
if m.closed.CompareAndSwap(false, true) {
|
||||||
|
close(m.closedCh)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnectAfterError(t *testing.T) {
|
||||||
|
first := newMockListener()
|
||||||
|
second := newMockListener()
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := &ConnTrack{
|
||||||
|
dial: func() (listener, error) {
|
||||||
|
n := callCount.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
return first, nil
|
||||||
|
}
|
||||||
|
return second, nil
|
||||||
|
},
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Inject an error on the first listener.
|
||||||
|
first.errChan <- assert.AnError
|
||||||
|
|
||||||
|
// Wait for reconnect to complete.
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return callCount.Load() >= 2
|
||||||
|
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
|
||||||
|
|
||||||
|
// The first connection must have been closed.
|
||||||
|
select {
|
||||||
|
case <-first.closedCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("first connection was not closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the receiver is still running by injecting and handling a second error.
|
||||||
|
third := newMockListener()
|
||||||
|
callCount.Store(2)
|
||||||
|
ct.mux.Lock()
|
||||||
|
ct.dial = func() (listener, error) {
|
||||||
|
callCount.Add(1)
|
||||||
|
return third, nil
|
||||||
|
}
|
||||||
|
ct.mux.Unlock()
|
||||||
|
|
||||||
|
second.errChan <- assert.AnError
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return callCount.Load() >= 3
|
||||||
|
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
|
||||||
|
|
||||||
|
ct.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopDuringReconnectBackoff(t *testing.T) {
|
||||||
|
mock := newMockListener()
|
||||||
|
|
||||||
|
ct := &ConnTrack{
|
||||||
|
dial: func() (listener, error) {
|
||||||
|
return mock, nil
|
||||||
|
},
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Trigger an error so the receiver enters reconnect.
|
||||||
|
mock.errChan <- assert.AnError
|
||||||
|
|
||||||
|
// Give the goroutine time to enter the reconnect backoff wait.
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
|
// Stop while reconnecting.
|
||||||
|
ct.Stop()
|
||||||
|
|
||||||
|
ct.mux.Lock()
|
||||||
|
assert.False(t, ct.started, "started should be false after Stop")
|
||||||
|
assert.Nil(t, ct.conn, "conn should be nil after Stop")
|
||||||
|
ct.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStopRaceWithReconnectDial(t *testing.T) {
|
||||||
|
first := newMockListener()
|
||||||
|
dialStarted := make(chan struct{})
|
||||||
|
dialProceed := make(chan struct{})
|
||||||
|
second := newMockListener()
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := &ConnTrack{
|
||||||
|
dial: func() (listener, error) {
|
||||||
|
n := callCount.Add(1)
|
||||||
|
if n == 1 {
|
||||||
|
return first, nil
|
||||||
|
}
|
||||||
|
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||||
|
close(dialStarted)
|
||||||
|
<-dialProceed
|
||||||
|
return second, nil
|
||||||
|
},
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Trigger error to enter reconnect.
|
||||||
|
first.errChan <- assert.AnError
|
||||||
|
|
||||||
|
// Wait for reconnect's second dial to begin.
|
||||||
|
select {
|
||||||
|
case <-dialStarted:
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for reconnect dial")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop while dial is in progress (conn is nil at this point).
|
||||||
|
ct.Stop()
|
||||||
|
|
||||||
|
// Let the dial complete. reconnect should detect started==false and close the new conn.
|
||||||
|
close(dialProceed)
|
||||||
|
|
||||||
|
// The second connection should be closed (not leaked).
|
||||||
|
select {
|
||||||
|
case <-second.closedCh:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("second connection was leaked after Stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
ct.mux.Lock()
|
||||||
|
assert.False(t, ct.started)
|
||||||
|
assert.Nil(t, ct.conn)
|
||||||
|
ct.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartIsIdempotent(t *testing.T) {
|
||||||
|
mock := newMockListener()
|
||||||
|
callCount := atomic.Int32{}
|
||||||
|
|
||||||
|
ct := &ConnTrack{
|
||||||
|
dial: func() (listener, error) {
|
||||||
|
callCount.Add(1)
|
||||||
|
return mock, nil
|
||||||
|
},
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Second Start should be a no-op.
|
||||||
|
err = ct.Start(false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
|
||||||
|
|
||||||
|
ct.Stop()
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user