diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 1d5390a57..af5606511 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -14,19 +14,27 @@ import ( ) type ProxyBind struct { - Bind *bind.ICEBind + bind *bind.ICEBind // wgEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address - wgEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + wgRelayedEndpoint *bind.Endpoint + wgCurrentUsed *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool +} + +func NewProxyBind(bind *bind.ICEBind) *ProxyBind { + return &ProxyBind{ + bind: bind, + pausedCond: sync.NewCond(&sync.Mutex{}), + } } // AddTurnConn adds a new connection to the bind. @@ -38,19 +46,19 @@ type ProxyBind struct { // - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address // - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { - fakeAddr, err := p.Bind.SetEndpoint(nbAddr, remoteConn) + fakeAddr, err := p.bind.SetEndpoint(nbAddr, remoteConn) if err != nil { return err } - p.wgEndpoint = addrToEndpoint(fakeAddr) + p.wgRelayedEndpoint = addrToEndpoint(fakeAddr) p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return err } func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return bind.EndpointToUDPAddr(*p.wgEndpoint) + return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) Work() { @@ -58,15 +66,20 @@ func (p *ProxyBind) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.L.Unlock() + // todo: review to should be inside the lock scope + p.pausedCond.Signal() } func (p *ProxyBind) Pause() { @@ -74,9 +87,19 @@ func (p *ProxyBind) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyBind) RedirectTo(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgCurrentUsed = addrToEndpoint(endpoint) + + p.pausedCond.L.Unlock() + p.pausedCond.Signal() } func (p *ProxyBind) CloseConn() error { @@ -97,7 +120,12 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(bind.EndpointToUDPAddr(*p.wgEndpoint)) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.L.Unlock() + p.pausedCond.Signal() + + p.bind.RemoveEndpoint(bind.EndpointToUDPAddr(*p.wgCurrentUsed)) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -123,18 +151,25 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + for { + p.pausedCond.L.Lock() + if p.paused { + p.pausedCond.Wait() + if !p.paused { + break + } + p.pausedCond.L.Unlock() + continue + } + break } msg := bind.RecvMessage{ - Endpoint: p.wgEndpoint, + Endpoint: p.wgCurrentUsed, Buffer: buf[:n], } - p.Bind.RecvChan <- msg - p.pausedMu.Unlock() + p.bind.RecvChan <- msg + p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 54cab4e1b..412afdee1 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -69,6 +69,10 @@ func (p *ProxyWrapper) Pause() { p.pausedMu.Unlock() } +func (p *ProxyWrapper) RedirectTo(endpoint *net.UDPAddr) { + // todo implement me +} + // CloseConn close the remoteConn and automatically remove the conn instance from the map func (e *ProxyWrapper) CloseConn() error { if e.cancel == nil { diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index e2d479331..141b4c1f9 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -20,9 +20,7 @@ func NewUSPFactory(iceBind *bind.ICEBind) *USPFactory { } func (w *USPFactory) GetProxy() Proxy { - return &proxyBind.ProxyBind{ - Bind: w.bind, - } + return proxyBind.NewProxyBind(w.bind) } func (w *USPFactory) Free() error { diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index 243aa2bd2..53a5ca7b9 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -12,4 +12,5 @@ type Proxy interface { Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. CloseConn() error + RedirectTo(endpoint *net.UDPAddr) } diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index ba0004b8a..93bb293fc 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -95,6 +95,10 @@ func (p *WGUDPProxy) Pause() { p.pausedMu.Unlock() } +func (p *WGUDPProxy) RedirectTo(endpoint *net.UDPAddr) { + // todo implement me +} + // CloseConn close the localConn func (p *WGUDPProxy) CloseConn() error { if p.cancel == nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 3639d62ee..927000647 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -370,18 +370,26 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC conn.workerRelay.DisableWgWatcher() if conn.wgProxyRelay != nil { + conn.log.Debugf("pause Relayed proxy") conn.wgProxyRelay.Pause() } if wgProxy != nil { + conn.log.Debugf("run ICE proxy") wgProxy.Work() } + conn.log.Infof("configure WireGuard endpoint to: %s", ep.String()) if err = conn.endpointUpdater.configureWGEndpoint(ep); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() + + if conn.wgProxyRelay != nil { + conn.wgProxyRelay.RedirectTo(ep) + } + conn.currentConnPriority = priority conn.statusICE.Set(StatusConnected) conn.updateIceState(iceConnInfo) @@ -407,11 +415,12 @@ func (conn *Conn) onICEStateDisconnected() { // switch back to relay connection if conn.isReadyToUpgrade() { conn.log.Infof("ICE disconnected, set Relay to active connection") - conn.wgProxyRelay.Work() if err := conn.endpointUpdater.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } + + conn.wgProxyRelay.Work() conn.workerRelay.EnableWgWatcher(conn.ctx) conn.currentConnPriority = connPriorityRelay } else {