diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go new file mode 100644 index 000000000..a05f4f83a --- /dev/null +++ b/client/internal/peer/wg_watcher.go @@ -0,0 +1,123 @@ +package peer + +import ( + "context" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +const ( + wgHandshakePeriod = 3 * time.Minute +) + +var ( + wgHandshakeOvertime = 30 * time.Second + wgReadErrorRetry = 5 * time.Second + checkPeriod = wgHandshakePeriod + wgHandshakeOvertime +) + +type WGInterfaceStater interface { + GetStats(key string) (configurer.WGStats, error) +} + +type WGWatcher struct { + log *log.Entry + wgIfaceStater WGInterfaceStater + peerKey string + + ctx context.Context + ctxCancel context.CancelFunc + ctxLock sync.Mutex +} + +func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string) *WGWatcher { + return &WGWatcher{ + log: log, + wgIfaceStater: wgIfaceStater, + peerKey: peerKey, + } +} + +func (w *WGWatcher) EnableWgWatcher(parentCtx context.Context, onDisconnectedFn func()) { + w.log.Debugf("enable WireGuard watcher") + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctx != nil && w.ctx.Err() == nil { + return + } + + ctx, ctxCancel := context.WithCancel(parentCtx) + w.ctx = ctx + w.ctxCancel = ctxCancel + + w.wgStateCheck(ctx, ctxCancel, onDisconnectedFn) +} + +func (w *WGWatcher) DisableWgWatcher() { + w.ctxLock.Lock() + defer w.ctxLock.Unlock() + + if w.ctxCancel == nil { + return + } + + w.log.Debugf("disable WireGuard watcher") + + w.ctxCancel() +} + +// wgStateCheck help to check the state of the WireGuard handshake and relay connection +func (w *WGWatcher) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc, onDisconnectedFn func()) { + w.log.Debugf("WireGuard watcher started") + lastHandshake, err := w.wgState() + if err != nil { + w.log.Warnf("failed to read wg stats: %v", err) + lastHandshake = time.Time{} + } + + go func(lastHandshake time.Time) { + timer := time.NewTimer(wgHandshakeOvertime) + defer timer.Stop() + defer ctxCancel() + + for { + select { + case <-timer.C: + handshake, err := w.wgState() + if err != nil { + w.log.Errorf("failed to read wg stats: %v", err) + timer.Reset(wgReadErrorRetry) + continue + } + + w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) + + if handshake.Equal(lastHandshake) { + w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) + onDisconnectedFn() + return + } + + resetTime := time.Until(handshake.Add(checkPeriod)) + lastHandshake = handshake + timer.Reset(resetTime) + case <-ctx.Done(): + w.log.Debugf("WireGuard watcher stopped") + return + } + } + }(lastHandshake) +} + +func (w *WGWatcher) wgState() (time.Time, error) { + wgState, err := w.wgIfaceStater.GetStats(w.peerKey) + if err != nil { + return time.Time{}, err + } + return wgState.LastHandshake, nil +} diff --git a/client/internal/peer/wg_watcher_test.go b/client/internal/peer/wg_watcher_test.go new file mode 100644 index 000000000..a5b9026ad --- /dev/null +++ b/client/internal/peer/wg_watcher_test.go @@ -0,0 +1,98 @@ +package peer + +import ( + "context" + "testing" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/configurer" +) + +type MocWgIface struct { + initial bool + lastHandshake time.Time + stop bool +} + +func (m *MocWgIface) GetStats(key string) (configurer.WGStats, error) { + if !m.initial { + m.initial = true + return configurer.WGStats{}, nil + } + + if !m.stop { + m.lastHandshake = time.Now() + } + + stats := configurer.WGStats{ + LastHandshake: m.lastHandshake, + } + + return stats, nil +} + +func (m *MocWgIface) disconnect() { + m.stop = true +} + +func TestWGWatcher_EnableWgWatcher(t *testing.T) { + checkPeriod = 5 * time.Second + wgHandshakeOvertime = 1 * time.Second + + mlog := log.WithField("peer", "tet") + mocWgIface := &MocWgIface{} + watcher := NewWGWatcher(mlog, mocWgIface, "") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + onDisconnected := make(chan struct{}, 1) + watcher.EnableWgWatcher(ctx, func() { + mlog.Infof("onDisconnectedFn") + onDisconnected <- struct{}{} + }) + + // wait for initial reading + time.Sleep(2 * time.Second) + mocWgIface.disconnect() + + select { + case <-onDisconnected: + case <-time.After(10 * time.Second): + t.Errorf("timeout") + } + watcher.DisableWgWatcher() +} + +func TestWGWatcher_ReEnable(t *testing.T) { + checkPeriod = 5 * time.Second + wgHandshakeOvertime = 1 * time.Second + + mlog := log.WithField("peer", "tet") + mocWgIface := &MocWgIface{} + watcher := NewWGWatcher(mlog, mocWgIface, "") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + onDisconnected := make(chan struct{}, 1) + + watcher.EnableWgWatcher(ctx, func() {}) + watcher.DisableWgWatcher() + + watcher.EnableWgWatcher(ctx, func() { + onDisconnected <- struct{}{} + }) + + time.Sleep(2 * time.Second) + mocWgIface.disconnect() + + select { + case <-onDisconnected: + case <-time.After(10 * time.Second): + t.Errorf("timeout") + } + watcher.DisableWgWatcher() +} diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index c22dcdeda..d2df8cd36 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -6,18 +6,12 @@ import ( "net" "sync" "sync/atomic" - "time" log "github.com/sirupsen/logrus" relayClient "github.com/netbirdio/netbird/relay/client" ) -var ( - wgHandshakePeriod = 3 * time.Minute - wgHandshakeOvertime = 30 * time.Second -) - type RelayConnInfo struct { relayedConn net.Conn rosenpassPubKey []byte @@ -36,13 +30,12 @@ type WorkerRelay struct { relayManager relayClient.ManagerService callBacks WorkerRelayCallbacks - relayedConn net.Conn - relayLock sync.Mutex - ctxWgWatch context.Context - ctxCancelWgWatch context.CancelFunc - ctxLock sync.Mutex + relayedConn net.Conn + relayLock sync.Mutex relaySupportedOnRemotePeer atomic.Bool + + wgWatcher *WGWatcher } func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager relayClient.ManagerService, callbacks WorkerRelayCallbacks) *WorkerRelay { @@ -52,6 +45,7 @@ func NewWorkerRelay(log *log.Entry, ctrl bool, config ConnConfig, relayManager r config: config, relayManager: relayManager, callBacks: callbacks, + wgWatcher: NewWGWatcher(log, config.WgConfig.WgInterface, config.Key), } return r } @@ -103,32 +97,11 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } func (w *WorkerRelay) EnableWgWatcher(ctx context.Context) { - w.log.Debugf("enable WireGuard watcher") - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxWgWatch != nil && w.ctxWgWatch.Err() == nil { - return - } - - ctx, ctxCancel := context.WithCancel(ctx) - w.ctxWgWatch = ctx - w.ctxCancelWgWatch = ctxCancel - - w.wgStateCheck(ctx, ctxCancel) + w.wgWatcher.EnableWgWatcher(ctx, w.disconnected) } func (w *WorkerRelay) DisableWgWatcher() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancelWgWatch == nil { - return - } - - w.log.Debugf("disable WireGuard watcher") - - w.ctxCancelWgWatch() + w.wgWatcher.DisableWgWatcher() } func (w *WorkerRelay) RelayInstanceAddress() (string, error) { @@ -150,57 +123,17 @@ func (w *WorkerRelay) CloseConn() { return } - err := w.relayedConn.Close() - if err != nil { + if err := w.relayedConn.Close(); err != nil { w.log.Warnf("failed to close relay connection: %v", err) } } -// wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) { - w.log.Debugf("WireGuard watcher started") - lastHandshake, err := w.wgState() - if err != nil { - w.log.Warnf("failed to read wg stats: %v", err) - lastHandshake = time.Time{} - } - - go func(lastHandshake time.Time) { - timer := time.NewTimer(wgHandshakeOvertime) - defer timer.Stop() - defer ctxCancel() - - for { - select { - case <-timer.C: - handshake, err := w.wgState() - if err != nil { - w.log.Errorf("failed to read wg stats: %v", err) - timer.Reset(wgHandshakeOvertime) - continue - } - - w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake) - - if handshake.Equal(lastHandshake) { - w.log.Infof("WireGuard handshake timed out, closing relay connection: %v", handshake) - w.relayLock.Lock() - _ = w.relayedConn.Close() - w.relayLock.Unlock() - w.callBacks.OnDisconnected() - return - } - - resetTime := time.Until(handshake.Add(wgHandshakePeriod + wgHandshakeOvertime)) - lastHandshake = handshake - timer.Reset(resetTime) - case <-ctx.Done(): - w.log.Debugf("WireGuard watcher stopped") - return - } - } - }(lastHandshake) +func (w *WorkerRelay) disconnected() { + w.relayLock.Lock() + _ = w.relayedConn.Close() + w.relayLock.Unlock() + w.callBacks.OnDisconnected() } func (w *WorkerRelay) isRelaySupported(answer *OfferAnswer) bool { @@ -217,20 +150,7 @@ func (w *WorkerRelay) preferredRelayServer(myRelayAddress, remoteRelayAddress st return remoteRelayAddress } -func (w *WorkerRelay) wgState() (time.Time, error) { - wgState, err := w.config.WgConfig.WgInterface.GetStats(w.config.Key) - if err != nil { - return time.Time{}, err - } - return wgState.LastHandshake, nil -} - func (w *WorkerRelay) onRelayMGDisconnected() { - w.ctxLock.Lock() - defer w.ctxLock.Unlock() - - if w.ctxCancelWgWatch != nil { - w.ctxCancelWgWatch() - } + w.wgWatcher.DisableWgWatcher() go w.callBacks.OnDisconnected() }