diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 6b1daf866..982337fbb 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -118,6 +118,8 @@ type Conn struct { // debug purpose dumpState *stateDump + + endpointUpdater *endpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -141,6 +143,11 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { statusRelay: worker.NewAtomicStatus(), statusICE: worker.NewAtomicStatus(), dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + endpointUpdater: &endpointUpdater{ + log: connLog, + wgConfig: config.WgConfig, + initiator: isWireGuardInitiator(config), + }, } return conn, nil @@ -250,7 +257,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.removeWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -377,13 +384,12 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn } conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) - if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { + if err = conn.endpointUpdater.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() - if conn.wgProxyRelay != nil { conn.Log.Debugf("redirect packages from relayed conn to WireGuard") conn.wgProxyRelay.RedirectAs(ep) @@ -417,7 +423,7 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { + if err := conn.endpointUpdater.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -486,7 +492,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { + if err := conn.endpointUpdater.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -554,17 +560,6 @@ func (conn *Conn) onGuardEvent() { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { - presharedKey := conn.presharedKey(remoteRPKey) - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - addr, - presharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -707,10 +702,6 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { @@ -791,6 +782,10 @@ func isController(config ConnConfig) bool { return config.LocalKey > config.Key } +func isWireGuardInitiator(config ConnConfig) bool { + return isController(config) +} + func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 000000000..b70c17782 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,88 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" +) + +// fallbackDelay could be const but because of testing it is a var +var fallbackDelay = 5 * time.Second + +type endpointUpdater struct { + log *logrus.Entry + wgConfig WgConfig + initiator bool + + cancelFunc func() + configUpdateMutex sync.Mutex +} + +// configureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { + if e.initiator { + return e.updateWireGuardPeer(addr, remoteRPKey) + } + + // prevent to run new update while cancel the previous update + e.configUpdateMutex.Lock() + if e.cancelFunc != nil { + e.cancelFunc() + } + e.configUpdateMutex.Unlock() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + go e.scheduleDelayedUpdate(ctx, addr, remoteRPKey) + + return e.updateWireGuardPeer(nil, remoteRPKey) +} + +func (e *endpointUpdater) removeWgPeer() error { + e.configUpdateMutex.Lock() + defer e.configUpdateMutex.Unlock() + + if e.cancelFunc != nil { + e.cancelFunc() + } + + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, remoteRPKey []byte) { + t := time.NewTimer(fallbackDelay) + defer t.Stop() + + select { + case <-ctx.Done(): + return + case <-t.C: + e.configUpdateMutex.Lock() + defer e.configUpdateMutex.Unlock() + + if ctx.Err() != nil { + return + } + + if err := e.updateWireGuardPeer(addr, remoteRPKey); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + } +} + +func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, remoteRPKey []byte) error { + // todo add, "presharedKey := e.presharedKey(remote)" + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + e.wgConfig.PreSharedKey, + ) +}