diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index a05f4f83a..a43301e73 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -16,7 +16,6 @@ const ( var ( wgHandshakeOvertime = 30 * time.Second - wgReadErrorRetry = 5 * time.Second checkPeriod = wgHandshakePeriod + wgHandshakeOvertime ) @@ -51,11 +50,14 @@ func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn return } - ctx, ctxCancel := context.WithCancel(parentCtx) - w.ctx = ctx - w.ctxCancel = ctxCancel + w.ctx, w.ctxCancel = context.WithCancel(parentCtx) - w.wgStateCheck(ctx, ctxCancel, onDisconnectedFn) + initialHandshake, err := w.wgState() + if err != nil { + w.log.Warnf("failed to read wg stats: %v", err) + } + + go w.periodicHandshakeCheck(w.ctx, w.ctxCancel, onDisconnectedFn, initialHandshake) } func (w *WGWatcher) DisableWgWatcher() { @@ -72,46 +74,30 @@ func (w *WGWatcher) DisableWgWatcher() { } // wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WGWatcher) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func()) { +func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func(), initialHandshake time.Time) { w.log.Debugf("WireGuard watcher started") - lastHandshake, err := w.wgState() - if err != nil { - w.log.Warnf("failed to read wg stats: %v", err) - lastHandshake = time.Time{} - } - go func(lastHandshake time.Time) { - timer := time.NewTimer(wgHandshakeOvertime) - defer timer.Stop() - defer ctxCancel() + timer := time.NewTimer(wgHandshakeOvertime) + defer timer.Stop() + defer ctxCancel() - for { - select { - case <-timer.C: - handshake, err := w.wgState() - if err != nil { - w.log.Errorf("failed to read wg stats: %v", err) - timer.Reset(wgReadErrorRetry) - continue - } + lastHandshake := initialHandshake - w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) - - if handshake.Equal(lastHandshake) { - w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) - onDisconnectedFn() - return - } - - resetTime := time.Until(handshake.Add(checkPeriod)) - lastHandshake = handshake - timer.Reset(resetTime) - case <-ctx.Done(): - w.log.Debugf("WireGuard watcher stopped") + for { + select { + case <-timer.C: + handshake, ok := w.handshakeCheck(lastHandshake) + if !ok { + onDisconnectedFn() return } + timer.Reset(time.Until(handshake.Add(checkPeriod))) + lastHandshake = *handshake + case <-ctx.Done(): + w.log.Debugf("WireGuard watcher stopped") + return } - }(lastHandshake) + } } func (w *WGWatcher) wgState() (time.Time, error) { @@ -121,3 +107,20 @@ func (w *WGWatcher) wgState() (time.Time, error) { } return wgState.LastHandshake, nil } + +func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) { + handshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + return nil, false + } + + w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) + + if handshake.Equal(lastHandshake) { + w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) + return nil, false + } + + return &handshake, true +}