Block on all subsystems on shutdown

This commit is contained in:
Viktor Liu
2025-10-27 22:57:37 +01:00
parent 7f08983207
commit ae801d77fb
6 changed files with 122 additions and 52 deletions

View File

@@ -200,8 +200,10 @@ type Engine struct {
flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
wgIfaceMonitor *WGIfaceMonitor
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup
// dns forwarder port
dnsFwdPort uint16
@@ -326,10 +328,6 @@ func (e *Engine) Stop() error {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
// stop flow manager after wg interface is gone
@@ -337,8 +335,6 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
@@ -349,12 +345,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())
baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout
maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}
return timeout
}
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -484,14 +520,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
e.shutdownWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
e.triggerClientRestart()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
@@ -892,7 +928,9 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if err != nil {
return fmt.Errorf("create ssh server: %w", err)
}
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// blocking
err = e.sshServer.Start()
if err != nil {
@@ -950,7 +988,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
@@ -1368,7 +1408,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
@@ -1724,8 +1766,10 @@ func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult {
)
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
// triggerClientRestart triggers a full client restart by cancelling the client context.
// Note: This does NOT just restart the engine - it cancels the entire client context,
// which causes the connect client's retry loop to create a completely new engine.
func (e *Engine) triggerClientRestart() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1747,7 +1791,9 @@ func (e *Engine) startNetworkMonitor() {
}
e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
@@ -1757,8 +1803,8 @@ func (e *Engine) startNetworkMonitor() {
return
}
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
log.Infof("Network monitor: detected network change, triggering client restart")
e.triggerClientRestart()
}()
}