mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
[client] Stop flow grpc receiver properly (#3596)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
@@ -15,7 +14,7 @@ import (
|
||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
func TestDefaultManager(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
)
|
||||
|
||||
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||
|
||||
type mocWGIface struct {
|
||||
filter device.PacketFilter
|
||||
|
||||
@@ -353,7 +353,7 @@ func (e *Engine) Start() error {
|
||||
|
||||
// start flow manager right after interface creation
|
||||
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
|
||||
|
||||
if e.config.RosenpassEnabled {
|
||||
log.Infof("rosenpass is enabled")
|
||||
|
||||
@@ -19,11 +19,9 @@ import (
|
||||
type rcvChan chan *types.EventFields
|
||||
type Logger struct {
|
||||
mux sync.Mutex
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
enabled atomic.Bool
|
||||
rcvChan atomic.Pointer[rcvChan]
|
||||
cancelReceiver context.CancelFunc
|
||||
cancel context.CancelFunc
|
||||
statusRecorder *peer.Status
|
||||
wgIfaceIPNet net.IPNet
|
||||
dnsCollection atomic.Bool
|
||||
@@ -31,12 +29,9 @@ type Logger struct {
|
||||
Store types.Store
|
||||
}
|
||||
|
||||
func New(ctx context.Context, statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
func New(statusRecorder *peer.Status, wgIfaceIPNet net.IPNet) *Logger {
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
return &Logger{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
statusRecorder: statusRecorder,
|
||||
wgIfaceIPNet: wgIfaceIPNet,
|
||||
Store: store.NewMemoryStore(),
|
||||
@@ -70,8 +65,8 @@ func (l *Logger) startReceiver() {
|
||||
}
|
||||
|
||||
l.mux.Lock()
|
||||
ctx, cancel := context.WithCancel(l.ctx)
|
||||
l.cancelReceiver = cancel
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
l.cancel = cancel
|
||||
l.mux.Unlock()
|
||||
|
||||
c := make(rcvChan, 100)
|
||||
@@ -109,7 +104,7 @@ func (l *Logger) startReceiver() {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Disable() {
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.Store.Close()
|
||||
}
|
||||
@@ -121,9 +116,9 @@ func (l *Logger) stop() {
|
||||
|
||||
l.enabled.Store(false)
|
||||
l.mux.Lock()
|
||||
if l.cancelReceiver != nil {
|
||||
l.cancelReceiver()
|
||||
l.cancelReceiver = nil
|
||||
if l.cancel != nil {
|
||||
l.cancel()
|
||||
l.cancel = nil
|
||||
}
|
||||
l.rcvChan.Store(nil)
|
||||
l.mux.Unlock()
|
||||
@@ -142,11 +137,6 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) {
|
||||
l.exitNodeCollection.Store(exitNodeCollection)
|
||||
}
|
||||
|
||||
func (l *Logger) Close() {
|
||||
l.stop()
|
||||
l.cancel()
|
||||
}
|
||||
|
||||
func (l *Logger) shouldStore(event *types.EventFields, isExitNode bool) bool {
|
||||
// check dns collection
|
||||
if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == dnsfwd.ListenPort) {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package logger_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func TestStore(t *testing.T) {
|
||||
logger := logger.New(context.Background(), nil, net.IPNet{})
|
||||
logger := logger.New(nil, net.IPNet{})
|
||||
logger.Enable()
|
||||
|
||||
event := types.EventFields{
|
||||
@@ -40,7 +39,7 @@ func TestStore(t *testing.T) {
|
||||
}
|
||||
|
||||
// test disable
|
||||
logger.Disable()
|
||||
logger.Close()
|
||||
wait()
|
||||
logger.StoreEvent(event)
|
||||
wait()
|
||||
|
||||
@@ -27,18 +27,18 @@ type Manager struct {
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
ctx context.Context
|
||||
receiverClient *client.GRPCClient
|
||||
publicKey []byte
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewManager creates a new netflow manager
|
||||
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
func NewManager(iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||
var ipNet net.IPNet
|
||||
if iface != nil {
|
||||
ipNet = *iface.Address().Network
|
||||
}
|
||||
flowLogger := logger.New(ctx, statusRecorder, ipNet)
|
||||
flowLogger := logger.New(statusRecorder, ipNet)
|
||||
|
||||
var ct nftypes.ConnTracker
|
||||
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||
@@ -48,7 +48,6 @@ func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte
|
||||
return &Manager{
|
||||
logger: flowLogger,
|
||||
conntrack: ct,
|
||||
ctx: ctx,
|
||||
publicKey: publicKey,
|
||||
}
|
||||
}
|
||||
@@ -68,21 +67,9 @@ func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
// first make sender ready so events don't pile up
|
||||
if m.needsNewClient(previous) {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
if err := m.resetClient(); err != nil {
|
||||
return fmt.Errorf("reset client: %w", err)
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
go m.receiveACKs(flowClient)
|
||||
go m.startSender()
|
||||
}
|
||||
|
||||
m.logger.Enable()
|
||||
@@ -96,17 +83,50 @@ func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetClient() error {
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("error closing previous flow client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create client: %w", err)
|
||||
}
|
||||
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||
|
||||
m.receiverClient = flowClient
|
||||
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// disableFlow stops components for flow tracking
|
||||
func (m *Manager) disableFlow() error {
|
||||
if m.cancel != nil {
|
||||
m.cancel()
|
||||
}
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Stop()
|
||||
}
|
||||
|
||||
m.logger.Disable()
|
||||
m.logger.Close()
|
||||
|
||||
if m.receiverClient != nil {
|
||||
return m.receiverClient.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -133,17 +153,18 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
|
||||
m.logger.UpdateConfig(update.DNSCollection, update.ExitNodeCollection)
|
||||
|
||||
changed := previous != nil && update.Enabled != previous.Enabled
|
||||
if update.Enabled {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
if changed {
|
||||
log.Infof("netflow manager enabled; starting netflow manager")
|
||||
}
|
||||
return m.enableFlow(previous)
|
||||
}
|
||||
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
err := m.disableFlow()
|
||||
if err != nil {
|
||||
log.Errorf("failed to disable netflow manager: %v", err)
|
||||
if changed {
|
||||
log.Infof("netflow manager disabled; stopping netflow manager")
|
||||
}
|
||||
return err
|
||||
return m.disableFlow()
|
||||
}
|
||||
|
||||
// Close cleans up all resources
|
||||
@@ -151,17 +172,9 @@ func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if m.conntrack != nil {
|
||||
m.conntrack.Close()
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
|
||||
if m.receiverClient != nil {
|
||||
if err := m.receiverClient.Close(); err != nil {
|
||||
log.Warnf("failed to close receiver client: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Close()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
@@ -169,13 +182,13 @@ func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||
return m.logger
|
||||
}
|
||||
|
||||
func (m *Manager) startSender() {
|
||||
func (m *Manager) startSender(ctx context.Context) {
|
||||
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-m.ctx.Done():
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
events := m.logger.GetEvents()
|
||||
@@ -190,8 +203,8 @@ func (m *Manager) startSender() {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
func (m *Manager) receiveACKs(ctx context.Context, client *client.GRPCClient) {
|
||||
err := client.Receive(ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||
id, err := uuid.FromBytes(ack.EventId)
|
||||
if err != nil {
|
||||
log.Warnf("failed to convert ack event id to uuid: %v", err)
|
||||
|
||||
200
client/internal/netflow/manager_test.go
Normal file
200
client/internal/netflow/manager_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package netflow
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type mockIFaceMapper struct {
|
||||
address wgaddr.Address
|
||||
isUserspaceBind bool
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Name() string {
|
||||
return "wt0"
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) Address() wgaddr.Address {
|
||||
return m.address
|
||||
}
|
||||
|
||||
func (m *mockIFaceMapper) IsUserspaceBind() bool {
|
||||
return m.isUserspaceBind
|
||||
}
|
||||
|
||||
func TestManager_Update(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
statusRecorder := peer.NewRecorder("")
|
||||
|
||||
manager := NewManager(mockIFace, publicKey, statusRecorder)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config *types.FlowConfig
|
||||
}{
|
||||
{
|
||||
name: "nil config",
|
||||
config: nil,
|
||||
},
|
||||
{
|
||||
name: "disabled config",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled config with minimal valid settings",
|
||||
config: &types.FlowConfig{
|
||||
Enabled: true,
|
||||
URL: "https://example.com",
|
||||
TokenPayload: "test-payload",
|
||||
TokenSignature: "test-signature",
|
||||
Interval: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := manager.Update(tc.config)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
if tc.config == nil {
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, manager.flowConfig)
|
||||
|
||||
if tc.config.Enabled {
|
||||
assert.Equal(t, tc.config.Enabled, manager.flowConfig.Enabled)
|
||||
}
|
||||
|
||||
if tc.config.URL != "" {
|
||||
assert.Equal(t, tc.config.URL, manager.flowConfig.URL)
|
||||
}
|
||||
|
||||
if tc.config.TokenPayload != "" {
|
||||
assert.Equal(t, tc.config.TokenPayload, manager.flowConfig.TokenPayload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestManager_Update_TokenPreservation(t *testing.T) {
|
||||
mockIFace := &mockIFaceMapper{
|
||||
address: wgaddr.Address{
|
||||
Network: &net.IPNet{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
Mask: net.CIDRMask(24, 32),
|
||||
},
|
||||
},
|
||||
isUserspaceBind: true,
|
||||
}
|
||||
|
||||
publicKey := []byte("test-public-key")
|
||||
manager := NewManager(mockIFace, publicKey, nil)
|
||||
|
||||
// First update with tokens
|
||||
initialConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
TokenPayload: "initial-payload",
|
||||
TokenSignature: "initial-signature",
|
||||
}
|
||||
|
||||
err := manager.Update(initialConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Second update without tokens should preserve them
|
||||
updatedConfig := &types.FlowConfig{
|
||||
Enabled: false,
|
||||
URL: "https://example.com",
|
||||
}
|
||||
|
||||
err = manager.Update(updatedConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify tokens were preserved
|
||||
assert.Equal(t, "initial-payload", manager.flowConfig.TokenPayload)
|
||||
assert.Equal(t, "initial-signature", manager.flowConfig.TokenSignature)
|
||||
}
|
||||
|
||||
func TestManager_NeedsNewClient(t *testing.T) {
|
||||
manager := &Manager{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
previous *types.FlowConfig
|
||||
current *types.FlowConfig
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil previous config",
|
||||
previous: nil,
|
||||
current: &types.FlowConfig{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "previous disabled",
|
||||
previous: &types.FlowConfig{Enabled: false},
|
||||
current: &types.FlowConfig{Enabled: true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different URL",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "old-url"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "new-url"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenPayload",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenPayload: "old-payload"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenPayload: "new-payload"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different TokenSignature",
|
||||
previous: &types.FlowConfig{Enabled: true, TokenSignature: "old-signature"},
|
||||
current: &types.FlowConfig{Enabled: true, TokenSignature: "new-signature"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same config",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "only interval changed",
|
||||
previous: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 30 * time.Second},
|
||||
current: &types.FlowConfig{Enabled: true, URL: "url", TokenPayload: "payload", TokenSignature: "signature", Interval: 60 * time.Second},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
manager.flowConfig = tc.current
|
||||
result := manager.needsNewClient(tc.previous)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -120,9 +120,6 @@ type FlowLogger interface {
|
||||
Close()
|
||||
// Enable enables the flow logger receiver
|
||||
Enable()
|
||||
// Disable disables the flow logger receiver
|
||||
Disable()
|
||||
|
||||
// UpdateConfig updates the flow manager configuration
|
||||
UpdateConfig(dnsCollection, exitNodeCollection bool)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user