diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index a43301e73..ff94e352d 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -28,9 +28,9 @@ type WGWatcher struct { wgIfaceStater WGInterfaceStater peerKey string - ctx context.Context ctxCancel context.CancelFunc ctxLock sync.Mutex + waitGroup sync.WaitGroup } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher { @@ -41,25 +41,30 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin } } +// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing. func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { w.log.Debugf("enable WireGuard watcher") w.ctxLock.Lock() defer w.ctxLock.Unlock() - if w.ctx != nil && w.ctx.Err() == nil { + if w.ctxCancel != nil { + w.log.Errorf("WireGuard watcher already enabled") return } - w.ctx, w.ctxCancel = context.WithCancel(parentCtx) + ctx, ctxCancel := context.WithCancel(parentCtx) + w.ctxCancel = ctxCancel initialHandshake, err := w.wgState() if err != nil { w.log.Warnf("failed to read wg stats: %v", err) } - go w.periodicHandshakeCheck(w.ctx, w.ctxCancel, onDisconnectedFn, initialHandshake) + w.waitGroup.Add(1) + go w.periodicHandshakeCheck(ctx, w.ctxCancel, onDisconnectedFn, initialHandshake) } +// DisableWgWatcher stops the WireGuard watcher and wait for the watcher to exit func (w *WGWatcher) DisableWgWatcher() { w.ctxLock.Lock() defer w.ctxLock.Unlock() @@ -71,11 +76,14 @@ func (w *WGWatcher) DisableWgWatcher() { w.log.Debugf("disable WireGuard watcher") w.ctxCancel() + w.ctxCancel = nil + w.waitGroup.Wait() } // wgStateCheck help to check the state of the WireGuard handshake and relay connection func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { w.log.Debugf("WireGuard watcher started") + defer w.waitGroup.Done() timer := time.NewTimer(wgHandshakeOvertime) defer timer.Stop()