mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
Add tests for conntrack reconnect behavior
This commit is contained in:
@@ -26,27 +26,38 @@ const (
|
||||
reconnectRandomization = 0.5
|
||||
)
|
||||
|
||||
// listener abstracts a netlink conntrack connection for testability.
|
||||
type listener interface {
|
||||
Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// ConnTrack manages kernel-based conntrack events
|
||||
type ConnTrack struct {
|
||||
flowLogger nftypes.FlowLogger
|
||||
iface nftypes.IFaceMapper
|
||||
|
||||
conn *nfct.Conn
|
||||
conn listener
|
||||
mux sync.Mutex
|
||||
|
||||
dial func() (listener, error)
|
||||
instanceID uuid.UUID
|
||||
started bool
|
||||
done chan struct{}
|
||||
sysctlModified bool
|
||||
}
|
||||
|
||||
func defaultDial() (listener, error) {
|
||||
return nfct.Dial(nil)
|
||||
}
|
||||
|
||||
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||
return &ConnTrack{
|
||||
flowLogger: flowLogger,
|
||||
iface: iface,
|
||||
instanceID: uuid.New(),
|
||||
started: false,
|
||||
dial: defaultDial,
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
@@ -66,7 +77,7 @@ func (c *ConnTrack) Start(enableCounters bool) error {
|
||||
c.EnableAccounting()
|
||||
}
|
||||
|
||||
conn, err := nfct.Dial(nil)
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial conntrack: %w", err)
|
||||
}
|
||||
@@ -160,7 +171,7 @@ func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) {
|
||||
case <-time.After(delay):
|
||||
}
|
||||
|
||||
conn, err := nfct.Dial(nil)
|
||||
conn, err := c.dial()
|
||||
if err != nil {
|
||||
log.Warnf("reconnect conntrack dial: %v", err)
|
||||
continue
|
||||
|
||||
198
client/internal/netflow/conntrack/conntrack_test.go
Normal file
198
client/internal/netflow/conntrack/conntrack_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
//go:build linux && !android
|
||||
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
nfct "github.com/ti-mo/conntrack"
|
||||
"github.com/ti-mo/netfilter"
|
||||
)
|
||||
|
||||
type mockListener struct {
|
||||
mu sync.Mutex
|
||||
errChan chan error
|
||||
closed atomic.Bool
|
||||
closedCh chan struct{}
|
||||
}
|
||||
|
||||
func newMockListener() *mockListener {
|
||||
return &mockListener{
|
||||
errChan: make(chan error, 1),
|
||||
closedCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) {
|
||||
return m.errChan, nil
|
||||
}
|
||||
|
||||
func (m *mockListener) Close() error {
|
||||
if m.closed.CompareAndSwap(false, true) {
|
||||
close(m.closedCh)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestReconnectAfterError(t *testing.T) {
|
||||
first := newMockListener()
|
||||
second := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
return second, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Inject an error on the first listener.
|
||||
first.errChan <- assert.AnError
|
||||
|
||||
// Wait for reconnect to complete.
|
||||
require.Eventually(t, func() bool {
|
||||
return callCount.Load() >= 2
|
||||
}, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection")
|
||||
|
||||
// The first connection must have been closed.
|
||||
select {
|
||||
case <-first.closedCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("first connection was not closed")
|
||||
}
|
||||
|
||||
// Verify the receiver is still running by injecting and handling a second error.
|
||||
third := newMockListener()
|
||||
callCount.Store(2)
|
||||
ct.mux.Lock()
|
||||
ct.dial = func() (listener, error) {
|
||||
callCount.Add(1)
|
||||
return third, nil
|
||||
}
|
||||
ct.mux.Unlock()
|
||||
|
||||
second.errChan <- assert.AnError
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return callCount.Load() >= 3
|
||||
}, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed")
|
||||
|
||||
ct.Stop()
|
||||
}
|
||||
|
||||
func TestStopDuringReconnectBackoff(t *testing.T) {
|
||||
mock := newMockListener()
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
return mock, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trigger an error so the receiver enters reconnect.
|
||||
mock.errChan <- assert.AnError
|
||||
|
||||
// Give the goroutine time to enter the reconnect backoff wait.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
// Stop while reconnecting.
|
||||
ct.Stop()
|
||||
|
||||
ct.mux.Lock()
|
||||
assert.False(t, ct.started, "started should be false after Stop")
|
||||
assert.Nil(t, ct.conn, "conn should be nil after Stop")
|
||||
ct.mux.Unlock()
|
||||
}
|
||||
|
||||
func TestStopRaceWithReconnectDial(t *testing.T) {
|
||||
first := newMockListener()
|
||||
dialStarted := make(chan struct{})
|
||||
dialProceed := make(chan struct{})
|
||||
second := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
n := callCount.Add(1)
|
||||
if n == 1 {
|
||||
return first, nil
|
||||
}
|
||||
// Second dial: signal that we're in progress, wait for test to call Stop.
|
||||
close(dialStarted)
|
||||
<-dialProceed
|
||||
return second, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Trigger error to enter reconnect.
|
||||
first.errChan <- assert.AnError
|
||||
|
||||
// Wait for reconnect's second dial to begin.
|
||||
select {
|
||||
case <-dialStarted:
|
||||
case <-time.After(15 * time.Second):
|
||||
t.Fatal("timed out waiting for reconnect dial")
|
||||
}
|
||||
|
||||
// Stop while dial is in progress (conn is nil at this point).
|
||||
ct.Stop()
|
||||
|
||||
// Let the dial complete. reconnect should detect started==false and close the new conn.
|
||||
close(dialProceed)
|
||||
|
||||
// The second connection should be closed (not leaked).
|
||||
select {
|
||||
case <-second.closedCh:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("second connection was leaked after Stop")
|
||||
}
|
||||
|
||||
ct.mux.Lock()
|
||||
assert.False(t, ct.started)
|
||||
assert.Nil(t, ct.conn)
|
||||
ct.mux.Unlock()
|
||||
}
|
||||
|
||||
func TestStartIsIdempotent(t *testing.T) {
|
||||
mock := newMockListener()
|
||||
callCount := atomic.Int32{}
|
||||
|
||||
ct := &ConnTrack{
|
||||
dial: func() (listener, error) {
|
||||
callCount.Add(1)
|
||||
return mock, nil
|
||||
},
|
||||
done: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
err := ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second Start should be a no-op.
|
||||
err = ct.Start(false)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once")
|
||||
|
||||
ct.Stop()
|
||||
}
|
||||
Reference in New Issue
Block a user