From bd58eea8ea253f0f5d8aae16c9f3c671875ad263 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Fri, 17 May 2024 09:43:18 +0200 Subject: [PATCH] Refactor network monitor to wait for stop (#1992) --- client/internal/engine.go | 47 ++++++++++++------- client/internal/networkmonitor/monitor.go | 14 ++++-- client/internal/networkmonitor/monitor_bsd.go | 8 ++-- .../networkmonitor/monitor_generic.go | 38 ++++++++------- .../internal/networkmonitor/monitor_linux.go | 10 ++-- .../internal/networkmonitor/monitor_mobile.go | 5 +- .../networkmonitor/monitor_windows.go | 4 +- 7 files changed, 74 insertions(+), 52 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index bdc05ea69..351b21b2e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -133,7 +133,7 @@ type Engine struct { // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service networkSerial uint64 - networkWatcher *networkmonitor.NetworkWatcher + networkMonitor *networkmonitor.NetworkMonitor sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error) sshServer nbssh.Server @@ -212,7 +212,6 @@ func NewEngineWithProbes( networkSerial: 0, sshServerFunc: nbssh.DefaultSSHServer, statusRecorder: statusRecorder, - networkWatcher: networkmonitor.New(), mgmProbe: mgmProbe, signalProbe: signalProbe, relayProbe: relayProbe, @@ -229,7 +228,10 @@ func (e *Engine) Stop() error { } // stopping network monitor first to avoid starting the engine again - e.networkWatcher.Stop() + if e.networkMonitor != nil { + e.networkMonitor.Stop() + } + log.Info("Network monitor: stopped") err := e.removeAllPeers() if err != nil { @@ -344,20 +346,8 @@ func (e *Engine) Start() error { e.receiveManagementEvents() e.receiveProbeEvents() - if e.config.NetworkMonitor { - // starting network monitor at the very last to avoid disruptions - go e.networkWatcher.Start(e.ctx, func() { - log.Infof("Network monitor detected network change, restarting engine") - if err := e.Stop(); err != nil { - log.Errorf("Failed to stop engine: %v", err) - } - if err := e.Start(); err != nil { - log.Errorf("Failed to start engine: %v", err) - } - }) - } else { - log.Infof("Network monitor is disabled, not starting") - } + // starting network monitor at the very last to avoid disruptions + e.startNetworkMonitor() return nil } @@ -1399,3 +1389,26 @@ func (e *Engine) probeSTUNs() []relay.ProbeResult { func (e *Engine) probeTURNs() []relay.ProbeResult { return relay.ProbeAll(e.ctx, relay.ProbeTURN, e.TURNs) } + +func (e *Engine) startNetworkMonitor() { + if !e.config.NetworkMonitor { + log.Infof("Network monitor is disabled, not starting") + return + } + + e.networkMonitor = networkmonitor.New() + go func() { + err := e.networkMonitor.Start(e.ctx, func() { + log.Infof("Network monitor detected network change, restarting engine") + if err := e.Stop(); err != nil { + log.Errorf("Failed to stop engine: %v", err) + } + if err := e.Start(); err != nil { + log.Errorf("Failed to start engine: %v", err) + } + }) + if err != nil && !errors.Is(err, networkmonitor.ErrStopped) { + log.Errorf("Network monitor: %v", err) + } + }() +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index 71cf031ba..5475455c6 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -2,14 +2,20 @@ package networkmonitor import ( "context" + "errors" + "sync" ) -// NetworkWatcher watches for changes in network configuration. -type NetworkWatcher struct { +var ErrStopped = errors.New("monitor has been stopped") + +// NetworkMonitor watches for changes in network configuration. +type NetworkMonitor struct { cancel context.CancelFunc + wg sync.WaitGroup + mu sync.Mutex } // New creates a new network monitor. -func New() *NetworkWatcher { - return &NetworkWatcher{} +func New() *NetworkMonitor { + return &NetworkMonitor{} } diff --git a/client/internal/networkmonitor/monitor_bsd.go b/client/internal/networkmonitor/monitor_bsd.go index e15c08d7e..de4209f5d 100644 --- a/client/internal/networkmonitor/monitor_bsd.go +++ b/client/internal/networkmonitor/monitor_bsd.go @@ -31,7 +31,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac for { select { case <-ctx.Done(): - return ctx.Err() + return ErrStopped default: buf := make([]byte, 2048) n, err := unix.Read(fd, buf) @@ -63,7 +63,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac } log.Infof("Network monitor: monitored interface (%s) is down.", ifinfo.Name) - callback() + go callback() // handle route changes case unix.RTM_ADD, syscall.RTM_DELETE: @@ -84,11 +84,11 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac switch msg.Type { case unix.RTM_ADD: log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - callback() + go callback() case unix.RTM_DELETE: if intfv4 != nil && route.Gw.Compare(nexthopv4) == 0 || intfv6 != nil && route.Gw.Compare(nexthopv6) == 0 { log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - callback() + go callback() } } } diff --git a/client/internal/networkmonitor/monitor_generic.go b/client/internal/networkmonitor/monitor_generic.go index 329246c8f..97cfbc2ca 100644 --- a/client/internal/networkmonitor/monitor_generic.go +++ b/client/internal/networkmonitor/monitor_generic.go @@ -5,6 +5,7 @@ package networkmonitor import ( "context" "errors" + "fmt" "net" "net/netip" "runtime/debug" @@ -15,20 +16,18 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" ) -// Start begins watching for network changes and calls the callback function and stops when a change is detected. -func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) { - if nw.cancel != nil { - log.Warn("Network monitor: already running, stopping previous watcher") - nw.Stop() - } - +// Start begins monitoring network changes. When a change is detected, it calls the callback asynchronously and returns. +func (nw *NetworkMonitor) Start(ctx context.Context, callback func()) (err error) { if ctx.Err() != nil { - log.Info("Network monitor: not starting, context is already cancelled") - return + return ctx.Err() } + nw.mu.Lock() ctx, nw.cancel = context.WithCancel(ctx) - defer nw.Stop() + nw.mu.Unlock() + + nw.wg.Add(1) + defer nw.wg.Done() var nexthop4, nexthop6 netip.Addr var intf4, intf6 *net.Interface @@ -56,27 +55,30 @@ func (nw *NetworkWatcher) Start(ctx context.Context, callback func()) { expBackOff := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) if err := backoff.Retry(operation, expBackOff); err != nil { - log.Errorf("Network monitor: failed to get default next hops: %v", err) - return + return fmt.Errorf("failed to get default next hops: %w", err) } // recover in case sys ops panic defer func() { if r := recover(); r != nil { - log.Errorf("Network monitor: panic occurred: %v, stack trace: %s", r, string(debug.Stack())) + err = fmt.Errorf("panic occurred: %v, stack trace: %s", r, string(debug.Stack())) } }() - if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil && !errors.Is(err, context.Canceled) { - log.Errorf("Network monitor: failed to start: %v", err) + if err := checkChange(ctx, nexthop4, intf4, nexthop6, intf6, callback); err != nil { + return fmt.Errorf("check change: %w", err) } + + return nil } // Stop stops the network monitor. -func (nw *NetworkWatcher) Stop() { +func (nw *NetworkMonitor) Stop() { + nw.mu.Lock() + defer nw.mu.Unlock() + if nw.cancel != nil { nw.cancel() - nw.cancel = nil - log.Info("Network monitor: stopped") + nw.wg.Wait() } } diff --git a/client/internal/networkmonitor/monitor_linux.go b/client/internal/networkmonitor/monitor_linux.go index f39f1235c..3f93c6ac6 100644 --- a/client/internal/networkmonitor/monitor_linux.go +++ b/client/internal/networkmonitor/monitor_linux.go @@ -36,7 +36,7 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac for { select { case <-ctx.Done(): - return ctx.Err() + return ErrStopped // handle interface state changes case update := <-linkChan: @@ -47,12 +47,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac switch update.Header.Type { case syscall.RTM_DELLINK: log.Infof("Network monitor: monitored interface (%s) is gone", update.Link.Attrs().Name) - callback() + go callback() return nil case syscall.RTM_NEWLINK: if (update.IfInfomsg.Flags&syscall.IFF_RUNNING) == 0 && update.Link.Attrs().OperState == netlink.OperDown { log.Infof("Network monitor: monitored interface (%s) is down.", update.Link.Attrs().Name) - callback() + go callback() return nil } } @@ -67,12 +67,12 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac // triggered on added/replaced routes case syscall.RTM_NEWROUTE: log.Infof("Network monitor: default route changed: via %s, interface %d", route.Gw, route.LinkIndex) - callback() + go callback() return nil case syscall.RTM_DELROUTE: if intfv4 != nil && route.Gw.Equal(nexthopv4.AsSlice()) || intfv6 != nil && route.Gw.Equal(nexthop6.AsSlice()) { log.Infof("Network monitor: default route removed: via %s, interface %d", route.Gw, route.LinkIndex) - callback() + go callback() return nil } } diff --git a/client/internal/networkmonitor/monitor_mobile.go b/client/internal/networkmonitor/monitor_mobile.go index 988f296bb..c81fad16c 100644 --- a/client/internal/networkmonitor/monitor_mobile.go +++ b/client/internal/networkmonitor/monitor_mobile.go @@ -4,8 +4,9 @@ package networkmonitor import "context" -func (nw *NetworkWatcher) Start(context.Context, func()) { +func (nw *NetworkMonitor) Start(context.Context, func()) error { + return nil } -func (nw *NetworkWatcher) Stop() { +func (nw *NetworkMonitor) Stop() { } diff --git a/client/internal/networkmonitor/monitor_windows.go b/client/internal/networkmonitor/monitor_windows.go index f6c5d963f..b8d9c6de7 100644 --- a/client/internal/networkmonitor/monitor_windows.go +++ b/client/internal/networkmonitor/monitor_windows.go @@ -48,10 +48,10 @@ func checkChange(ctx context.Context, nexthopv4 netip.Addr, intfv4 *net.Interfac for { select { case <-ctx.Done(): - return ctx.Err() + return ErrStopped case <-ticker.C: if changed(nexthopv4, intfv4, neighborv4, nexthopv6, intfv6, neighborv6) { - callback() + go callback() return nil } }