diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 423a626d1..8c23bee85 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/proxy" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/iface" + "github.com/netbirdio/netbird/iface/bind" signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/version" @@ -326,12 +327,7 @@ func (conn *Conn) Open() error { // * Local peer uses userspace interface with bind.ICEBind and is not relayed // // Please note, that this check happens when peers were already able to ping each other using ICE layer. -func shouldUseProxy(pair *ice.CandidatePair, userspaceBind bool) bool { - - if !isRelayCandidate(pair.Local) && userspaceBind { - log.Debugf("shouldn't use proxy because using Bind and the connection is not relayed") - return false - } +func shouldUseProxy(pair *ice.CandidatePair) bool { if !isHardNATCandidate(pair.Local) && isHostCandidateWithPublicIP(pair.Remote) { log.Debugf("shouldn't use proxy because the local peer is not behind a hard NAT and the remote one has a public IP") @@ -436,26 +432,54 @@ func (conn *Conn) startProxy(remoteConn net.Conn, remoteWgPort int) error { } func (conn *Conn) getProxyWithMessageExchange(pair *ice.CandidatePair, remoteWgPort int) proxy.Proxy { - useProxy := shouldUseProxy(pair, conn.config.UserspaceBind) - localDirectMode := !useProxy - remoteDirectMode := localDirectMode - if conn.meta.protoSupport.DirectCheck { - go conn.sendLocalDirectMode(localDirectMode) - // will block until message received or timeout - remoteDirectMode = conn.receiveRemoteDirectMode() + if !conn.config.UserspaceBind { + useProxy := shouldUseProxy(pair) + localDirectMode := !useProxy + remoteDirectMode := localDirectMode + + if conn.meta.protoSupport.DirectCheck { + go conn.sendLocalDirectMode(localDirectMode) + // will block until message received or timeout + remoteDirectMode = conn.receiveRemoteDirectMode() + } + + if localDirectMode && remoteDirectMode { + return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort) + } + + log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key) + return proxy.NewWireGuardProxy(conn.config.ProxyConfig) } - if conn.config.UserspaceBind && localDirectMode { - return proxy.NewNoProxy(conn.config.ProxyConfig) + if isRelayCandidate(pair.Local) { + return proxy.NewWireGuardProxy(conn.config.ProxyConfig) } - if localDirectMode && remoteDirectMode { - return proxy.NewDirectNoProxy(conn.config.ProxyConfig, remoteWgPort) + // We decided to ignore the proxy decision when using Bind. Instead, we always punch remote WireGuard port to open a + // hole in the firewall for that remote port to avoid cases when old clients assumes direct mode. + mux := conn.config.UDPMuxSrflx.(*bind.UniversalUDPMuxDefault) + err := punchRemote(pair, remoteWgPort, mux) + if err != nil { + log.Warnf("failed to punch remote WireGuard port") } - log.Debugf("falling back to local proxy mode with peer %s", conn.config.Key) - return proxy.NewWireGuardProxy(conn.config.ProxyConfig) + return proxy.NewNoProxy(conn.config.ProxyConfig) + +} + +func punchRemote(pair *ice.CandidatePair, remoteWgPort int, muxDefault *bind.UniversalUDPMuxDefault) error { + addr, err := net.ResolveUDPAddr("udp", pair.Remote.Address()) + if err != nil { + return err + } + addr.Port = remoteWgPort + + _, err = muxDefault.GetSharedConn().WriteTo([]byte{1}, addr) + if err != nil { + return err + } + return err } func (conn *Conn) sendLocalDirectMode(localMode bool) { diff --git a/iface/bind/udp_mux_universal.go b/iface/bind/udp_mux_universal.go index 91f0ee18a..445c3532b 100644 --- a/iface/bind/udp_mux_universal.go +++ b/iface/bind/udp_mux_universal.go @@ -75,6 +75,10 @@ type udpConn struct { logger logging.LeveledLogger } +func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn { + return m.params.UDPConn +} + // GetListenAddresses returns the listen addr of this UDP func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr { return []net.Addr{m.LocalAddr()}