From ae801d77fb579cb7c6d39ab1cc45979c2e91c753 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Mon, 27 Oct 2025 22:57:37 +0100 Subject: [PATCH] Block on all subsystems on shutdown --- client/internal/connect.go | 32 +++++----- client/internal/dns/server.go | 9 ++- client/internal/engine.go | 80 +++++++++++++++++++----- client/internal/netflow/manager.go | 17 +++-- client/internal/peer/guard/sr_watcher.go | 22 ++++--- client/internal/routemanager/manager.go | 14 ++++- 6 files changed, 122 insertions(+), 52 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index c9331baf5..bb7c2b38b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -25,6 +25,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/stdnet" + nbnet "github.com/netbirdio/netbird/client/net" cProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -34,7 +35,6 @@ import ( relayClient "github.com/netbirdio/netbird/shared/relay/client" signal "github.com/netbirdio/netbird/shared/signal/client" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/version" ) @@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } <-engineCtx.Done() + c.engineMutex.Lock() - if c.engine != nil && c.engine.wgInterface != nil { - log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) - if err := c.engine.Stop(); err != nil { + engine := c.engine + c.engine = nil + c.engineMutex.Unlock() + + if engine != nil && engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - c.engine = nil } - c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() @@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType { } func (c *ConnectClient) Stop() error { - if c == nil { - return nil + engine := c.Engine() + if engine != nil { + if err := engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } } - c.engineMutex.Lock() - defer c.engineMutex.Unlock() - - if c.engine == nil { - return nil - } - if err := c.engine.Stop(); err != nil { - return fmt.Errorf("stop engine: %w", err) - } - return nil } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..afaf0579f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + shutdownWg sync.WaitGroup // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. // This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool @@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { s.ctxCancel() + s.shutdownWg.Wait() s.mux.Lock() defer s.mux.Unlock() @@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.applyHostConfig() + s.shutdownWg.Add(1) go func() { - // persist dns state right away + defer s.shutdownWg.Done() if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } diff --git a/client/internal/engine.go b/client/internal/engine.go index bebf04f6c..dc92855f0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -200,8 +200,10 @@ type Engine struct { flowManager nftypes.FlowManager // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup + wgIfaceMonitor *WGIfaceMonitor + + // shutdownWg tracks all long-running goroutines to ensure clean shutdown + shutdownWg sync.WaitGroup // dns forwarder port dnsFwdPort uint16 @@ -326,10 +328,6 @@ func (e *Engine) Stop() error { e.cancel() } - // very ugly but we want to remove peers from the WireGuard interface first before removing interface. - // Removing peers happens in the conn.Close() asynchronously - time.Sleep(500 * time.Millisecond) - e.close() // stop flow manager after wg interface is gone @@ -337,8 +335,6 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - log.Infof("stopped Netbird Engine") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -349,12 +345,52 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() + timeout := e.calculateShutdownTimeout() + log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil { + log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout) + } + + log.Infof("stopped Netbird Engine") return nil } +// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. +func (e *Engine) calculateShutdownTimeout() time.Duration { + peerCount := len(e.peerStore.PeersPubKey()) + + baseTimeout := 10 * time.Second + perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond + timeout := baseTimeout + perPeerTimeout + + maxTimeout := 30 * time.Second + if timeout > maxTimeout { + timeout = maxTimeout + } + + return timeout +} + +// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout. +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. // However, they will be established once an event with a list of peers to connect to will be received from Management Service @@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // monitor WireGuard interface lifecycle and restart engine on changes e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) + e.shutdownWg.Add(1) go func() { - defer e.wgIfaceMonitorWg.Done() + defer e.shutdownWg.Done() if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() + e.triggerClientRestart() } else if err != nil { log.Warnf("WireGuard interface monitor: %s", err) } @@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { if err != nil { return fmt.Errorf("create ssh server: %w", err) } + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // blocking err = e.sshServer.Start() if err != nil { @@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() info, err := system.GetInfoWithChecks(e.ctx, e.checks) if err != nil { log.Warnf("failed to get system info with checks: %v", err) @@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() @@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { ) } -// restartEngine restarts the engine by cancelling the client context -func (e *Engine) restartEngine() { +// triggerClientRestart triggers a full client restart by cancelling the client context. +// Note: This does NOT just restart the engine - it cancels the entire client context, +// which causes the connect client's retry loop to create a completely new engine. +func (e *Engine) triggerClientRestart() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() { } e.networkMonitor = networkmonitor.New() + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() if err := e.networkMonitor.Listen(e.ctx); err != nil { if errors.Is(err, context.Canceled) { log.Infof("network monitor stopped") @@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() { return } - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() + log.Infof("Network monitor: detected network change, triggering client restart") + e.triggerClientRestart() }() } diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index e3b188468..7752c97b0 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -24,6 +24,7 @@ import ( // Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex + shutdownWg sync.WaitGroup logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker @@ -105,8 +106,15 @@ func (m *Manager) resetClient() error { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - go m.receiveACKs(ctx, flowClient) - go m.startSender(ctx) + m.shutdownWg.Add(2) + go func() { + defer m.shutdownWg.Done() + m.receiveACKs(ctx, flowClient) + }() + go func() { + defer m.shutdownWg.Done() + m.startSender(ctx) + }() return nil } @@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { // Close cleans up all resources func (m *Manager) Close() { m.mux.Lock() - defer m.mux.Unlock() - if err := m.disableFlow(); err != nil { log.Warnf("failed to disable flow manager: %v", err) } + m.mux.Unlock() + + m.shutdownWg.Wait() } // GetLogger returns the flow logger diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 686430752..510bc390d 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -19,11 +19,11 @@ type SRWatcher struct { signalClient chNotifier relayManager chNotifier - listeners map[chan struct{}]struct{} - mu sync.Mutex - iFaceDiscover stdnet.ExternalIFaceDiscover - iceConfig ice.Config - + listeners map[chan struct{}]struct{} + mu sync.Mutex + shutdownWg sync.WaitGroup + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config cancelIceMonitor context.CancelFunc } @@ -52,7 +52,11 @@ func (w *SRWatcher) Start() { w.cancelIceMonitor = cancel iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) - go iceMonitor.Start(ctx, w.onICEChanged) + w.shutdownWg.Add(1) + go func() { + defer w.shutdownWg.Done() + iceMonitor.Start(ctx, w.onICEChanged) + }() w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) @@ -60,14 +64,16 @@ func (w *SRWatcher) Start() { func (w *SRWatcher) Close() { w.mu.Lock() - defer w.mu.Unlock() - if w.cancelIceMonitor == nil { + w.mu.Unlock() return } w.cancelIceMonitor() w.signalClient.SetOnReconnectedListener(nil) w.relayManager.SetOnReconnectedListener(nil) + w.mu.Unlock() + + w.shutdownWg.Wait() } func (w *SRWatcher) NewListener() chan struct{} { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index d590dba0d..7db5487c6 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -78,6 +78,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex + shutdownWg sync.WaitGroup clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector serverRouter *server.Router @@ -273,6 +274,7 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() + m.shutdownWg.Wait() if m.serverRouter != nil { m.serverRouter.CleanUp() } @@ -474,7 +476,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { } clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } @@ -516,7 +522,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout } clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() } update := client.RoutesUpdate{ UpdateSerial: updateSerial,