diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index eb455431d..af6ab3f83 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -410,7 +410,7 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn conn.doOnConnected(iceConnInfo.RosenpassPubKey, iceConnInfo.RosenpassAddr) } -func (conn *Conn) onICEStateDisconnected() { +func (conn *Conn) onICEStateDisconnected(sessionChanged bool) { conn.mu.Lock() defer conn.mu.Unlock() @@ -430,6 +430,10 @@ func (conn *Conn) onICEStateDisconnected() { if conn.isReadyToUpgrade() { conn.Log.Infof("ICE disconnected, set Relay to active connection") conn.dumpState.SwitchToRelay() + if sessionChanged { + conn.resetEndpoint() + } + conn.wgProxyRelay.Work() presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) @@ -757,6 +761,17 @@ func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { return wgProxy, nil } +func (conn *Conn) resetEndpoint() { + if !isController(conn.config) { + return + } + conn.Log.Infof("reset wg endpoint") + conn.wgWatcher.Reset() + if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil { + conn.Log.Warnf("failed to remove endpoint address before update: %v", err) + } +} + func (conn *Conn) isReadyToUpgrade() bool { return conn.wgProxyRelay != nil && conn.currentConnPriority != conntype.Relay } diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go index 52d66159c..372f33ec6 100644 --- a/client/internal/peer/endpoint.go +++ b/client/internal/peer/endpoint.go @@ -66,6 +66,10 @@ func (e *EndpointUpdater) RemoveWgPeer() error { return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) } +func (e *EndpointUpdater) RemoveEndpointAddress() error { + return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey) +} + func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { if e.cancelFunc == nil { return diff --git a/client/internal/peer/wg_watcher.go b/client/internal/peer/wg_watcher.go index d40ec7a80..799a9375e 100644 --- a/client/internal/peer/wg_watcher.go +++ b/client/internal/peer/wg_watcher.go @@ -32,6 +32,8 @@ type WGWatcher struct { enabled bool muEnabled sync.RWMutex + + resetCh chan struct{} } func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey string, stateDump *stateDump) *WGWatcher { @@ -40,6 +42,7 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin wgIfaceStater: wgIfaceStater, peerKey: peerKey, stateDump: stateDump, + resetCh: make(chan struct{}, 1), } } @@ -76,6 +79,15 @@ func (w *WGWatcher) IsEnabled() bool { return w.enabled } +// Reset signals the watcher that the WireGuard peer has been reset and a new +// handshake is expected. This restarts the handshake timeout from scratch. +func (w *WGWatcher) Reset() { + select { + case w.resetCh <- struct{}{}: + default: + } +} + // wgStateCheck help to check the state of the WireGuard handshake and relay connection func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn func(), enabledTime time.Time, initialHandshake time.Time) { w.log.Infof("WireGuard watcher started") @@ -105,6 +117,12 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn w.stateDump.WGcheckSuccess() w.log.Debugf("WireGuard watcher reset timer: %v", resetTime) + case <-w.resetCh: + w.log.Infof("WireGuard watcher received peer reset, restarting handshake timeout") + lastHandshake = time.Time{} + enabledTime = time.Now() + timer.Stop() + timer.Reset(wgHandshakeOvertime) case <-ctx.Done(): w.log.Infof("WireGuard watcher stopped") return diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 464f57bff..edd70fb20 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -52,8 +52,9 @@ type WorkerICE struct { // increase by one when disconnecting the agent // with it the remote peer can discard the already deprecated offer/answer // Without it the remote peer may recreate a workable ICE connection - sessionID ICESessionID - muxAgent sync.Mutex + sessionID ICESessionID + remoteSessionChanged bool + muxAgent sync.Mutex localUfrag string localPwd string @@ -106,6 +107,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { return } w.log.Debugf("agent already exists, recreate the connection") + w.remoteSessionChanged = true w.agentDialerCancel() if w.agent != nil { if err := w.agent.Close(); err != nil { @@ -306,13 +308,17 @@ func (w *WorkerICE) connect(ctx context.Context, agent *icemaker.ThreadSafeAgent w.conn.onICEConnectionIsReady(selectedPriority(pair), ci) } -func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) { +func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.CancelFunc) bool { cancel() if err := agent.Close(); err != nil { w.log.Warnf("failed to close ICE agent: %s", err) } w.muxAgent.Lock() + defer w.muxAgent.Unlock() + + sessionChanged := w.remoteSessionChanged + w.remoteSessionChanged = false if w.agent == agent { // consider to remove from here and move to the OnNewOffer @@ -325,7 +331,7 @@ func (w *WorkerICE) closeAgent(agent *icemaker.ThreadSafeAgent, cancel context.C w.agentConnecting = false w.remoteSessionID = "" } - w.muxAgent.Unlock() + return sessionChanged } func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { @@ -426,11 +432,11 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia // ice.ConnectionStateClosed happens when we recreate the agent. For the P2P to TURN switch important to // notify the conn.onICEStateDisconnected changes to update the current used priority - w.closeAgent(agent, dialerCancel) + sessionChanged := w.closeAgent(agent, dialerCancel) if w.lastKnownState == ice.ConnectionStateConnected { w.lastKnownState = ice.ConnectionStateDisconnected - w.conn.onICEStateDisconnected() + w.conn.onICEStateDisconnected(sessionChanged) } default: return