From e8d301fdc9357b9ca9ecf9ce40d4d347255d3bf9 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 30 Sep 2025 15:31:18 +0200 Subject: [PATCH] [client] Fix/pkg loss (#3338) The Relayed connection setup is optimistic. It does not have any confirmation of an established end-to-end connection. Peers start sending WireGuard handshake packets immediately after the successful offer-answer handshake. Meanwhile, for successful P2P connection negotiation, we change the WireGuard endpoint address, but this change does not trigger new handshake initiation. Because the peer switched from Relayed connection to P2P, the packets from the Relay server are dropped and must wait for the next WireGuard handshake via P2P. To avoid this scenario, the relayed WireGuard proxy no longer drops the packets. Instead, it rewrites the source address to the new P2P endpoint and continues forwarding the packets. We still have one corner case: if the Relayed server negotiation chooses a server that has not been used before. In this case, one side of the peer connection will be slower to reach the Relay server, and the Relay server will drop the handshake packet. If everything goes well we should see exactly 5 seconds improvements between the WireGuard configuration time and the handshake time. --- client/iface/bind/endpoint.go | 14 +- client/iface/bind/ice_bind.go | 15 +- client/iface/iface_new_freebsd.go | 41 +++++ .../{iface_new_unix.go => iface_new_linux.go} | 2 +- client/iface/wgproxy/bind/proxy.go | 107 ++++++++---- client/iface/wgproxy/ebpf/proxy.go | 59 ++----- client/iface/wgproxy/ebpf/wrapper.go | 79 +++++---- client/iface/wgproxy/factory_kernel.go | 1 - .../iface/wgproxy/factory_kernel_freebsd.go | 31 ---- client/iface/wgproxy/proxy.go | 5 + client/iface/wgproxy/proxy_linux_test.go | 104 +++++++----- client/iface/wgproxy/proxy_seed_test.go | 39 +++++ client/iface/wgproxy/proxy_test.go | 152 ++++++++++++++---- client/iface/wgproxy/rawsocket/rawsocket.go | 50 ++++++ client/iface/wgproxy/udp/proxy.go | 94 ++++++++--- client/iface/wgproxy/udp/rawsocket.go | 101 ++++++++++++ client/internal/peer/conn.go | 62 ++++--- client/internal/peer/endpoint.go | 105 ++++++++++++ 18 files changed, 784 insertions(+), 277 deletions(-) create mode 100644 client/iface/iface_new_freebsd.go rename client/iface/{iface_new_unix.go => iface_new_linux.go} (97%) delete mode 100644 client/iface/wgproxy/factory_kernel_freebsd.go create mode 100644 client/iface/wgproxy/proxy_seed_test.go create mode 100644 client/iface/wgproxy/rawsocket/rawsocket.go create mode 100644 client/iface/wgproxy/udp/rawsocket.go create mode 100644 client/internal/peer/endpoint.go diff --git a/client/iface/bind/endpoint.go b/client/iface/bind/endpoint.go index 1926ff88f..caa92f05d 100644 --- a/client/iface/bind/endpoint.go +++ b/client/iface/bind/endpoint.go @@ -1,5 +1,17 @@ package bind -import wgConn "golang.zx2c4.com/wireguard/conn" +import ( + "net" + + wgConn "golang.zx2c4.com/wireguard/conn" +) type Endpoint = wgConn.StdNetEndpoint + +func EndpointToUDPAddr(e Endpoint) *net.UDPAddr { + return &net.UDPAddr{ + IP: e.Addr().AsSlice(), + Port: int(e.Port()), + Zone: e.Addr().Zone(), + } +} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index 577c7c0c4..ef630b9d0 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -1,6 +1,7 @@ package bind import ( + "context" "encoding/binary" "fmt" "net" @@ -42,7 +43,7 @@ func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UD // use the port because in the Send function the wgConn.Endpoint the port info is not exported. type ICEBind struct { *wgConn.StdNetBind - RecvChan chan RecvMessage + recvChan chan RecvMessage transportNet transport.Net filterFn udpmux.FilterFn @@ -65,7 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wg b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) ib := &ICEBind{ StdNetBind: b, - RecvChan: make(chan RecvMessage, 1), + recvChan: make(chan RecvMessage, 1), transportNet: transportNet, filterFn: filterFn, endpoints: make(map[netip.Addr]net.Conn), @@ -155,6 +156,14 @@ func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { return nil } +func (b *ICEBind) Recv(ctx context.Context, msg RecvMessage) { + select { + case <-ctx.Done(): + return + case b.recvChan <- msg: + } +} + func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgsPool *sync.Pool) wgConn.ReceiveFunc { s.muUDPMux.Lock() defer s.muUDPMux.Unlock() @@ -271,7 +280,7 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo select { case <-c.closedChan: return 0, net.ErrClosed - case msg, ok := <-c.RecvChan: + case msg, ok := <-c.recvChan: if !ok { return 0, net.ErrClosed } diff --git a/client/iface/iface_new_freebsd.go b/client/iface/iface_new_freebsd.go new file mode 100644 index 000000000..86ed14ce1 --- /dev/null +++ b/client/iface/iface_new_freebsd.go @@ -0,0 +1,41 @@ +//go:build freebsd + +package iface + +import ( + "fmt" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// NewWGIFace Creates a new WireGuard interface instance +func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { + wgAddress, err := wgaddr.ParseWGAddress(opts.Address) + if err != nil { + return nil, err + } + + wgIFace := &WGIface{} + + if netstack.IsEnabled() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + if device.ModuleTunIsLoaded() { + iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress, opts.MTU) + wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) + wgIFace.userspaceBind = true + wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) + return wgIFace, nil + } + + return nil, fmt.Errorf("couldn't check or load tun module") +} diff --git a/client/iface/iface_new_unix.go b/client/iface/iface_new_linux.go similarity index 97% rename from client/iface/iface_new_unix.go rename to client/iface/iface_new_linux.go index 493144f13..77fd30fae 100644 --- a/client/iface/iface_new_unix.go +++ b/client/iface/iface_new_linux.go @@ -1,4 +1,4 @@ -//go:build (linux && !android) || freebsd +//go:build linux && !android package iface diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index bf6da72c2..dbc694e91 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -16,28 +16,37 @@ import ( "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +type IceBind interface { + SetEndpoint(fakeIP netip.Addr, conn net.Conn) + RemoveEndpoint(fakeIP netip.Addr) + Recv(ctx context.Context, msg bind.RecvMessage) + MTU() uint16 +} + type ProxyBind struct { - Bind *bind.ICEBind + bind IceBind - fakeNetIP *netip.AddrPort - wgBindEndpoint *bind.Endpoint - remoteConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + // wgRelayedEndpoint is a fake address that generated by the Bind.SetEndpoint based on the remote NetBird peer address + wgRelayedEndpoint *bind.Endpoint + wgCurrentUsed *bind.Endpoint + remoteConn net.Conn + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyBind(bind *bind.ICEBind) *ProxyBind { +func NewProxyBind(bind IceBind) *ProxyBind { p := &ProxyBind{ - Bind: bind, + bind: bind, closeListener: listener.NewCloseListener(), + pausedCond: sync.NewCond(&sync.Mutex{}), } return p @@ -46,25 +55,25 @@ func NewProxyBind(bind *bind.ICEBind) *ProxyBind { // AddTurnConn adds a new connection to the bind. // endpoint is the NetBird address of the remote peer. The SetEndpoint return with the address what will be used in the // WireGuard configuration. +// +// Parameters: +// - ctx: Context is used for proxyToLocal to avoid unnecessary error messages +// - nbAddr: The NetBird UDP address of the remote peer, it required to generate fake address +// - remoteConn: The established TURN connection to the remote peer func (p *ProxyBind) AddTurnConn(ctx context.Context, nbAddr *net.UDPAddr, remoteConn net.Conn) error { fakeNetIP, err := fakeAddress(nbAddr) if err != nil { return err } - - p.fakeNetIP = fakeNetIP - p.wgBindEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} + p.wgRelayedEndpoint = &bind.Endpoint{AddrPort: *fakeNetIP} p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) return nil } + func (p *ProxyBind) EndpointAddr() *net.UDPAddr { - return &net.UDPAddr{ - IP: p.fakeNetIP.Addr().AsSlice(), - Port: int(p.fakeNetIP.Port()), - Zone: p.fakeNetIP.Addr().Zone(), - } + return bind.EndpointToUDPAddr(*p.wgRelayedEndpoint) } func (p *ProxyBind) SetDisconnectListener(disconnected func()) { @@ -76,17 +85,21 @@ func (p *ProxyBind) Work() { return } - p.Bind.SetEndpoint(p.fakeNetIP.Addr(), p.remoteConn) + p.bind.SetEndpoint(p.wgRelayedEndpoint.Addr(), p.remoteConn) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgCurrentUsed = p.wgRelayedEndpoint // Start the proxy only once if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyBind) Pause() { @@ -94,9 +107,25 @@ func (p *ProxyBind) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgCurrentUsed = addrToEndpoint(endpoint) + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() +} + +func addrToEndpoint(addr *net.UDPAddr) *bind.Endpoint { + ip, _ := netip.AddrFromSlice(addr.IP.To4()) + addrPort := netip.AddrPortFrom(ip, uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort} } func (p *ProxyBind) CloseConn() error { @@ -107,6 +136,10 @@ func (p *ProxyBind) CloseConn() error { } func (p *ProxyBind) close() error { + if p.remoteConn == nil { + return nil + } + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -120,7 +153,12 @@ func (p *ProxyBind) close() error { p.cancel() - p.Bind.RemoveEndpoint(p.fakeNetIP.Addr()) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + p.bind.RemoveEndpoint(p.wgRelayedEndpoint.Addr()) if rErr := p.remoteConn.Close(); rErr != nil && !errors.Is(rErr, net.ErrClosed) { return rErr @@ -136,7 +174,7 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { }() for { - buf := make([]byte, p.Bind.MTU()+bufsize.WGBufferOverhead) + buf := make([]byte, p.bind.MTU()+bufsize.WGBufferOverhead) n, err := p.remoteConn.Read(buf) if err != nil { if ctx.Err() != nil { @@ -147,18 +185,17 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } msg := bind.RecvMessage{ - Endpoint: p.wgBindEndpoint, + Endpoint: p.wgCurrentUsed, Buffer: buf[:n], } - p.Bind.RecvChan <- msg - p.pausedMu.Unlock() + p.bind.Recv(ctx, msg) + p.pausedCond.L.Unlock() } } diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index b899f1694..858143091 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -6,9 +6,7 @@ import ( "context" "fmt" "net" - "os" "sync" - "syscall" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -18,6 +16,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bufsize" + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" "github.com/netbirdio/netbird/client/internal/ebpf" ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager" nbnet "github.com/netbirdio/netbird/client/net" @@ -27,6 +26,10 @@ const ( loopbackAddr = "127.0.0.1" ) +var ( + localHostNetIP = net.ParseIP("127.0.0.1") +) + // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -64,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = p.prepareSenderRawSocket() + p.rawConn, err = rawsocket.PrepareSenderRawSocket() if err != nil { return err } @@ -214,57 +217,17 @@ generatePort: return p.lastUsedPort, nil } -func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) { - // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) - if err != nil { - return nil, fmt.Errorf("creating raw socket failed: %w", err) - } - - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) - } - - // Bind the socket to the "lo" interface. - err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") - if err != nil { - return nil, fmt.Errorf("binding to lo interface failed: %w", err) - } - - // Set the fwmark on the socket. - err = nbnet.SetSocketOpt(fd) - if err != nil { - return nil, fmt.Errorf("setting fwmark failed: %w", err) - } - - // Convert the file descriptor to a PacketConn. - file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) - if file == nil { - return nil, fmt.Errorf("converting fd to file failed") - } - packetConn, err := net.FilePacketConn(file) - if err != nil { - return nil, fmt.Errorf("converting file to packet conn failed: %w", err) - } - - return packetConn, nil -} - -func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { - localhost := net.ParseIP("127.0.0.1") - +func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { payload := gopacket.Payload(data) ipH := &layers.IPv4{ - DstIP: localhost, - SrcIP: localhost, + DstIP: localHostNetIP, + SrcIP: endpointAddr.IP, Version: 4, TTL: 64, Protocol: layers.IPProtocolUDP, } udpH := &layers.UDP{ - SrcPort: layers.UDPPort(port), + SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } @@ -279,7 +242,7 @@ func (p *WGEBPFProxy) sendPkg(data []byte, port int) error { if err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localhost}); err != nil { + if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 3d71b01bd..ff44d30c0 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -18,41 +18,42 @@ import ( // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { - WgeBPFProxy *WGEBPFProxy + wgeBPFProxy *WGEBPFProxy remoteConn net.Conn ctx context.Context cancel context.CancelFunc - wgEndpointAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + wgEndpointCurrentUsedAddr *net.UDPAddr - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } -func NewProxyWrapper(WgeBPFProxy *WGEBPFProxy) *ProxyWrapper { +func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { return &ProxyWrapper{ - WgeBPFProxy: WgeBPFProxy, + wgeBPFProxy: proxy, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } } - func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { - addr, err := p.WgeBPFProxy.AddTurnConn(remoteConn) + addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) - p.wgEndpointAddr = addr + p.wgRelayedEndpointAddr = addr return err } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { - return p.wgEndpointAddr + return p.wgRelayedEndpointAddr } func (p *ProxyWrapper) SetDisconnectListener(disconnected func()) { @@ -64,14 +65,18 @@ func (p *ProxyWrapper) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + + p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr if !p.isStarted { p.isStarted = true go p.proxyToLocal(p.ctx) } + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } func (p *ProxyWrapper) Pause() { @@ -80,45 +85,59 @@ func (p *ProxyWrapper) Pause() { } log.Tracef("pause proxy reading from: %s", p.remoteConn.RemoteAddr()) - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + p.paused = false + + p.wgEndpointCurrentUsedAddr = endpoint + + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // CloseConn close the remoteConn and automatically remove the conn instance from the map -func (e *ProxyWrapper) CloseConn() error { - if e.cancel == nil { +func (p *ProxyWrapper) CloseConn() error { + if p.cancel == nil { return fmt.Errorf("proxy not started") } - e.cancel() + p.cancel() - e.closeListener.SetCloseListener(nil) + p.closeListener.SetCloseListener(nil) - if err := e.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { - return fmt.Errorf("close remote conn: %w", err) + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + return fmt.Errorf("failed to close remote conn: %w", err) } return nil } func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { - defer p.WgeBPFProxy.removeTurnConn(uint16(p.wgEndpointAddr.Port)) + defer p.wgeBPFProxy.removeTurnConn(uint16(p.wgRelayedEndpointAddr.Port)) - buf := make([]byte, p.WgeBPFProxy.mtu+bufsize.WGBufferOverhead) + buf := make([]byte, p.wgeBPFProxy.mtu+bufsize.WGBufferOverhead) for { n, err := p.readFromRemote(ctx, buf) if err != nil { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - err = p.WgeBPFProxy.sendPkg(buf[:n], p.wgEndpointAddr.Port) - p.pausedMu.Unlock() + err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { @@ -137,7 +156,7 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } p.closeListener.Notify() if !errors.Is(err, io.EOF) { - log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgEndpointAddr.Port, err) + log.Errorf("failed to read from turn conn (endpoint: :%d): %s", p.wgRelayedEndpointAddr.Port, err) } return 0, err } diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 63bc2ed24..ad2807546 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -39,7 +39,6 @@ func (w *KernelFactory) GetProxy() Proxy { } return ebpf.NewProxyWrapper(w.ebpfProxy) - } func (w *KernelFactory) Free() error { diff --git a/client/iface/wgproxy/factory_kernel_freebsd.go b/client/iface/wgproxy/factory_kernel_freebsd.go deleted file mode 100644 index 039f1cd3a..000000000 --- a/client/iface/wgproxy/factory_kernel_freebsd.go +++ /dev/null @@ -1,31 +0,0 @@ -package wgproxy - -import ( - log "github.com/sirupsen/logrus" - - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" -) - -// KernelFactory todo: check eBPF support on FreeBSD -type KernelFactory struct { - wgPort int - mtu uint16 -} - -func NewKernelFactory(wgPort int, mtu uint16) *KernelFactory { - log.Infof("WireGuard Proxy Factory will produce UDP proxy") - f := &KernelFactory{ - wgPort: wgPort, - mtu: mtu, - } - - return f -} - -func (w *KernelFactory) GetProxy() Proxy { - return udpProxy.NewWGUDPProxy(w.wgPort, w.mtu) -} - -func (w *KernelFactory) Free() error { - return nil -} diff --git a/client/iface/wgproxy/proxy.go b/client/iface/wgproxy/proxy.go index c2879877e..3c8dfd30e 100644 --- a/client/iface/wgproxy/proxy.go +++ b/client/iface/wgproxy/proxy.go @@ -11,6 +11,11 @@ type Proxy interface { EndpointAddr() *net.UDPAddr // EndpointAddr returns the address of the WireGuard peer endpoint Work() // Work start or resume the proxy Pause() // Pause to forward the packages from remote connection to WireGuard. The opposite way still works. + + //RedirectAs resume the forwarding the packages from relayed connection to WireGuard interface if it was paused + //and rewrite the src address to the endpoint address. + //With this logic can avoid the package loss from relayed connections. + RedirectAs(endpoint *net.UDPAddr) CloseConn() error SetDisconnectListener(disconnected func()) } diff --git a/client/iface/wgproxy/proxy_linux_test.go b/client/iface/wgproxy/proxy_linux_test.go index 5add503e1..9526e91d2 100644 --- a/client/iface/wgproxy/proxy_linux_test.go +++ b/client/iface/wgproxy/proxy_linux_test.go @@ -3,54 +3,82 @@ package wgproxy import ( - "context" - "os" - "testing" + "fmt" + "net" + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" ) -func TestProxyCloseByRemoteConnEBPF(t *testing.T) { - if os.Getenv("GITHUB_ACTIONS") != "true" { - t.Skip("Skipping test as it requires root privileges") - } - ctx := context.Background() +func seedProxies() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - - tests := []struct { - name string - proxy Proxy - }{ - { - name: "ebpf proxy", - proxy: &ebpf.ProxyWrapper{ - WgeBPFProxy: ebpfProxy, - }, - }, + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, } + pl = append(pl, pEbpf) - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) - if err != nil { - t.Errorf("error: %v", err) - } - - _ = relayedConn.Close() - if err := tt.proxy.CloseConn(); err != nil { - t.Errorf("error: %v", err) - } - }) + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, } + pl = append(pl, pUDP) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + + ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) + if err := ebpfProxy.Listen(); err != nil { + return nil, fmt.Errorf("failed to initialize ebpf proxy: %s", err) + } + + pEbpf := proxyInstance{ + name: "ebpf kernel proxy", + proxy: ebpf.NewProxyWrapper(ebpfProxy), + wgPort: 51831, + closeFn: ebpfProxy.Free, + } + pl = append(pl, pEbpf) + + pUDP := proxyInstance{ + name: "udp kernel proxy", + proxy: udp.NewWGUDPProxy(51832, 1280), + wgPort: 51832, + closeFn: func() error { return nil }, + } + pl = append(pl, pUDP) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + + return pl, nil } diff --git a/client/iface/wgproxy/proxy_seed_test.go b/client/iface/wgproxy/proxy_seed_test.go new file mode 100644 index 000000000..4d244f18a --- /dev/null +++ b/client/iface/wgproxy/proxy_seed_test.go @@ -0,0 +1,39 @@ +//go:build !linux + +package wgproxy + +import ( + "net" + + "github.com/netbirdio/netbird/client/iface/bind" + "github.com/netbirdio/netbird/client/iface/wgaddr" + bindproxy "github.com/netbirdio/netbird/client/iface/wgproxy/bind" +) + +func seedProxies() ([]proxyInstance, error) { + // todo extend with Bind proxy + pl := make([]proxyInstance, 0) + return pl, nil +} + +func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) { + pl := make([]proxyInstance, 0) + wgAddress, err := wgaddr.ParseWGAddress("10.0.0.1/32") + if err != nil { + return nil, err + } + iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280) + endpointAddress := &net.UDPAddr{ + IP: net.IPv4(10, 0, 0, 1), + Port: 1234, + } + + pBind := proxyInstance{ + name: "bind proxy", + proxy: bindproxy.NewProxyBind(iceBind), + endpointAddr: endpointAddress, + closeFn: func() error { return nil }, + } + pl = append(pl, pBind) + return pl, nil +} diff --git a/client/iface/wgproxy/proxy_test.go b/client/iface/wgproxy/proxy_test.go index 76e5ed6f7..1aeab66b7 100644 --- a/client/iface/wgproxy/proxy_test.go +++ b/client/iface/wgproxy/proxy_test.go @@ -1,5 +1,3 @@ -//go:build linux - package wgproxy import ( @@ -7,12 +5,9 @@ import ( "io" "net" "os" - "runtime" "testing" "time" - "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" - udpProxy "github.com/netbirdio/netbird/client/iface/wgproxy/udp" "github.com/netbirdio/netbird/util" ) @@ -22,6 +17,14 @@ func TestMain(m *testing.M) { os.Exit(code) } +type proxyInstance struct { + name string + proxy Proxy + wgPort int + endpointAddr *net.UDPAddr + closeFn func() error +} + type mocConn struct { closeChan chan struct{} closed bool @@ -78,41 +81,21 @@ func (m *mocConn) SetWriteDeadline(t time.Time) error { func TestProxyCloseByRemoteConn(t *testing.T) { ctx := context.Background() - tests := []struct { - name string - proxy Proxy - }{ - { - name: "userspace proxy", - proxy: udpProxy.NewWGUDPProxy(51830, 1280), - }, + tests, err := seedProxyForProxyCloseByRemoteConn() + if err != nil { + t.Fatalf("error: %v", err) } - if runtime.GOOS == "linux" && os.Getenv("GITHUB_ACTIONS") != "true" { - ebpfProxy := ebpf.NewWGEBPFProxy(51831, 1280) - if err := ebpfProxy.Listen(); err != nil { - t.Fatalf("failed to initialize ebpf proxy: %s", err) - } - defer func() { - if err := ebpfProxy.Free(); err != nil { - t.Errorf("failed to free ebpf proxy: %s", err) - } - }() - proxyWrapper := ebpf.NewProxyWrapper(ebpfProxy) - - tests = append(tests, struct { - name string - proxy Proxy - }{ - name: "ebpf proxy", - proxy: proxyWrapper, - }) - } + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + defer func() { + _ = relayedConn.Close() + }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + addr, _ := net.ResolveUDPAddr("udp", "100.108.135.221:51892") relayedConn := newMockConn() - err := tt.proxy.AddTurnConn(ctx, nil, relayedConn) + err := tt.proxy.AddTurnConn(ctx, addr, relayedConn) if err != nil { t.Errorf("error: %v", err) } @@ -124,3 +107,104 @@ func TestProxyCloseByRemoteConn(t *testing.T) { }) } } + +// TestProxyRedirect todo extend the proxies with Bind proxy +func TestProxyRedirect(t *testing.T) { + tests, err := seedProxies() + if err != nil { + t.Fatalf("error: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + redirectTraffic(t, tt.proxy, tt.wgPort, tt.endpointAddr) + if err := tt.closeFn(); err != nil { + t.Errorf("error: %v", err) + } + }) + } +} + +func redirectTraffic(t *testing.T, proxy Proxy, wgPort int, endPointAddr *net.UDPAddr) { + t.Helper() + + msgHelloFromRelay := []byte("hello from relay") + msgRedirected := [][]byte{ + []byte("hello 1. to p2p"), + []byte("hello 2. to p2p"), + []byte("hello 3. to p2p"), + } + + dummyWgListener, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: wgPort}) + if err != nil { + t.Fatalf("failed to listen on udp port: %s", err) + } + + relayedServer, _ := net.ListenUDP("udp", + &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 1234, + }, + ) + + relayedConn, _ := net.Dial("udp", "127.0.0.1:1234") + + defer func() { + _ = dummyWgListener.Close() + _ = relayedConn.Close() + _ = relayedServer.Close() + }() + + if err := proxy.AddTurnConn(context.Background(), endPointAddr, relayedConn); err != nil { + t.Errorf("error: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("error: %v", err) + } + }() + + proxy.Work() + + if _, err := relayedServer.WriteTo(msgHelloFromRelay, relayedConn.LocalAddr()); err != nil { + t.Errorf("error relayedServer.Write(msgHelloFromRelay): %v", err) + } + + n, err := dummyWgListener.Read(make([]byte, 1024)) + if err != nil { + t.Errorf("error: %v", err) + } + + if n != len(msgHelloFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgHelloFromRelay), n) + } + + p2pEndpointAddr := &net.UDPAddr{ + IP: net.IPv4(192, 168, 0, 56), + Port: 1234, + } + proxy.RedirectAs(p2pEndpointAddr) + + for _, msg := range msgRedirected { + if _, err := relayedServer.WriteTo(msg, relayedConn.LocalAddr()); err != nil { + t.Errorf("error: %v", err) + } + } + + for i := 0; i < len(msgRedirected); i++ { + buf := make([]byte, 1024) + n, rAddr, err := dummyWgListener.ReadFrom(buf) + if err != nil { + t.Errorf("error: %v", err) + } + + if rAddr.String() != p2pEndpointAddr.String() { + t.Errorf("expected %s, got %s", p2pEndpointAddr.String(), rAddr.String()) + } + if string(buf[:n]) != string(msgRedirected[i]) { + t.Errorf("expected %s, got %s", string(msgRedirected[i]), string(buf[:n])) + } + } +} diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go new file mode 100644 index 000000000..a11ac46d5 --- /dev/null +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -0,0 +1,50 @@ +//go:build linux && !android + +package rawsocket + +import ( + "fmt" + "net" + "os" + "syscall" + + nbnet "github.com/netbirdio/netbird/client/net" +) + +func PrepareSenderRawSocket() (net.PacketConn, error) { + // Create a raw socket. + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + if err != nil { + return nil, fmt.Errorf("creating raw socket failed: %w", err) + } + + // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) + if err != nil { + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + + // Bind the socket to the "lo" interface. + err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") + if err != nil { + return nil, fmt.Errorf("binding to lo interface failed: %w", err) + } + + // Set the fwmark on the socket. + err = nbnet.SetSocketOpt(fd) + if err != nil { + return nil, fmt.Errorf("setting fwmark failed: %w", err) + } + + // Convert the file descriptor to a PacketConn. + file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) + if file == nil { + return nil, fmt.Errorf("converting fd to file failed") + } + packetConn, err := net.FilePacketConn(file) + if err != nil { + return nil, fmt.Errorf("converting file to packet conn failed: %w", err) + } + + return packetConn, nil +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index be65e2b27..4ef2f19c4 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -1,3 +1,5 @@ +//go:build linux && !android + package udp import ( @@ -21,16 +23,18 @@ type WGUDPProxy struct { localWGListenPort int mtu uint16 - remoteConn net.Conn - localConn net.Conn - ctx context.Context - cancel context.CancelFunc - closeMu sync.Mutex - closed bool + remoteConn net.Conn + localConn net.Conn + srcFakerConn *SrcFaker + sendPkg func(data []byte) (int, error) + ctx context.Context + cancel context.CancelFunc + closeMu sync.Mutex + closed bool - pausedMu sync.Mutex - paused bool - isStarted bool + paused bool + pausedCond *sync.Cond + isStarted bool closeListener *listener.CloseListener } @@ -41,6 +45,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { p := &WGUDPProxy{ localWGListenPort: wgPort, mtu: mtu, + pausedCond: sync.NewCond(&sync.Mutex{}), closeListener: listener.NewCloseListener(), } return p @@ -61,6 +66,7 @@ func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, rem p.ctx, p.cancel = context.WithCancel(ctx) p.localConn = localConn + p.sendPkg = p.localConn.Write p.remoteConn = remoteConn return err @@ -84,15 +90,24 @@ func (p *WGUDPProxy) Work() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = false - p.pausedMu.Unlock() + p.sendPkg = p.localConn.Write + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } if !p.isStarted { p.isStarted = true go p.proxyToRemote(p.ctx) go p.proxyToLocal(p.ctx) } + p.pausedCond.Signal() + p.pausedCond.L.Unlock() } // Pause pauses the proxy from receiving data from the remote peer @@ -101,9 +116,35 @@ func (p *WGUDPProxy) Pause() { return } - p.pausedMu.Lock() + p.pausedCond.L.Lock() p.paused = true - p.pausedMu.Unlock() + p.pausedCond.L.Unlock() +} + +// RedirectAs start to use the fake sourced raw socket as package sender +func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) { + p.pausedCond.L.Lock() + defer func() { + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + }() + + p.paused = false + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + log.Errorf("failed to close src faker conn: %s", err) + } + p.srcFakerConn = nil + } + srcFakerConn, err := NewSrcFaker(p.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create src faker conn: %s", err) + // fallback to continue without redirecting + p.paused = true + return + } + p.srcFakerConn = srcFakerConn + p.sendPkg = p.srcFakerConn.SendPkg } // CloseConn close the localConn @@ -115,6 +156,8 @@ func (p *WGUDPProxy) CloseConn() error { } func (p *WGUDPProxy) close() error { + var result *multierror.Error + p.closeMu.Lock() defer p.closeMu.Unlock() @@ -128,7 +171,11 @@ func (p *WGUDPProxy) close() error { p.cancel() - var result *multierror.Error + p.pausedCond.L.Lock() + p.paused = false + p.pausedCond.Signal() + p.pausedCond.L.Unlock() + if err := p.remoteConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { result = multierror.Append(result, fmt.Errorf("remote conn: %s", err)) } @@ -136,6 +183,13 @@ func (p *WGUDPProxy) close() error { if err := p.localConn.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("local conn: %s", err)) } + + if p.srcFakerConn != nil { + if err := p.srcFakerConn.Close(); err != nil { + result = multierror.Append(result, fmt.Errorf("src faker raw conn: %s", err)) + } + } + return cerrors.FormatErrorOrNil(result) } @@ -194,14 +248,12 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { return } - p.pausedMu.Lock() - if p.paused { - p.pausedMu.Unlock() - continue + p.pausedCond.L.Lock() + for p.paused { + p.pausedCond.Wait() } - - _, err = p.localConn.Write(buf[:n]) - p.pausedMu.Unlock() + _, err = p.sendPkg(buf[:n]) + p.pausedCond.L.Unlock() if err != nil { if ctx.Err() != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go new file mode 100644 index 000000000..fdc911463 --- /dev/null +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -0,0 +1,101 @@ +//go:build linux && !android + +package udp + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/wgproxy/rawsocket" +) + +var ( + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + localHostNetIPAddr = &net.IPAddr{ + IP: net.ParseIP("127.0.0.1"), + } +) + +type SrcFaker struct { + srcAddr *net.UDPAddr + + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer +} + +func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { + rawSocket, err := rawsocket.PrepareSenderRawSocket() + if err != nil { + return nil, err + } + + ipH, udpH, err := prepareHeaders(dstPort, srcAddr) + if err != nil { + return nil, err + } + + f := &SrcFaker{ + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + } + + return f, nil +} + +func (f *SrcFaker) Close() error { + return f.rawSocket.Close() +} + +func (f *SrcFaker) SendPkg(data []byte) (int, error) { + defer func() { + if err := f.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + err := gopacket.SerializeLayers(f.layerBuffer, serializeOpts, f.ipH, f.udpH, payload) + if err != nil { + return 0, fmt.Errorf("serialize layers: %w", err) + } + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + if err != nil { + return 0, fmt.Errorf("write to raw conn: %w", err) + } + return n, nil +} + +func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { + ipH := &layers.IPv4{ + DstIP: net.ParseIP("127.0.0.1"), + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(srcAddr.Port), + DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port + } + + err := udpH.SetNetworkLayerForChecksum(ipH) + if err != nil { + return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return ipH, udpH, nil +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 86e4596d4..8db9e58f4 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -28,10 +28,6 @@ import ( semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) -const ( - defaultWgKeepAlive = 25 * time.Second -) - type ServiceDependencies struct { StatusRecorder *Status Signaler *Signaler @@ -117,6 +113,8 @@ type Conn struct { // debug purpose dumpState *stateDump + + endpointUpdater *EndpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -129,17 +127,18 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) { connLog := log.WithField("peer", config.Key) var conn = &Conn{ - Log: connLog, - config: config, - statusRecorder: services.StatusRecorder, - signaler: services.Signaler, - iFaceDiscover: services.IFaceDiscover, - relayManager: services.RelayManager, - srWatcher: services.SrWatcher, - semaphore: services.Semaphore, - statusRelay: worker.NewAtomicStatus(), - statusICE: worker.NewAtomicStatus(), - dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + Log: connLog, + config: config, + statusRecorder: services.StatusRecorder, + signaler: services.Signaler, + iFaceDiscover: services.IFaceDiscover, + relayManager: services.RelayManager, + srWatcher: services.SrWatcher, + semaphore: services.Semaphore, + statusRelay: worker.NewAtomicStatus(), + statusICE: worker.NewAtomicStatus(), + dumpState: newStateDump(config.Key, connLog, services.StatusRecorder), + endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)), } return conn, nil @@ -249,7 +248,7 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.RemoveWgPeer(); err != nil { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } @@ -375,12 +374,19 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn wgProxy.Work() } - if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { + conn.Log.Infof("configure WireGuard endpoint to: %s", ep.String()) + presharedKey := conn.presharedKey(iceConnInfo.RosenpassPubKey) + if err = conn.endpointUpdater.ConfigureWGEndpoint(ep, presharedKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } wgConfigWorkaround() + if conn.wgProxyRelay != nil { + conn.Log.Debugf("redirect packets from relayed conn to WireGuard") + conn.wgProxyRelay.RedirectAs(ep) + } + conn.currentConnPriority = priority conn.statusICE.SetConnected() conn.updateIceState(iceConnInfo) @@ -409,7 +415,8 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { + presharedKey := conn.presharedKey(conn.rosenpassRemoteKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil { conn.Log.Errorf("failed to switch to relay conn: %v", err) } @@ -418,6 +425,7 @@ func (conn *Conn) onICEStateDisconnected() { defer conn.wgWatcherWg.Done() conn.workerRelay.EnableWgWatcher(conn.ctx) }() + conn.wgProxyRelay.Work() conn.currentConnPriority = conntype.Relay } else { conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String()) @@ -477,7 +485,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { + presharedKey := conn.presharedKey(rci.rosenpassPubKey) + if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.Log.Warnf("Failed to close relay connection: %v", err) } @@ -545,17 +554,6 @@ func (conn *Conn) onGuardEvent() { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { - presharedKey := conn.presharedKey(remoteRPKey) - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - addr, - presharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -698,10 +696,6 @@ func (conn *Conn) isICEActive() bool { return (conn.currentConnPriority == conntype.ICEP2P || conn.currentConnPriority == conntype.ICETurn) && conn.statusICE.Get() == worker.StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.Log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 000000000..39cb95591 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,105 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + defaultWgKeepAlive = 25 * time.Second + fallbackDelay = 5 * time.Second +) + +type EndpointUpdater struct { + log *logrus.Entry + wgConfig WgConfig + initiator bool + + // mu protects updateWireGuardPeer and cancelFunc + mu sync.Mutex + cancelFunc func() + updateWg sync.WaitGroup +} + +func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *EndpointUpdater { + return &EndpointUpdater{ + log: log, + wgConfig: wgConfig, + initiator: initiator, + } +} + +// ConfigureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.initiator { + e.log.Debugf("configure up WireGuard as initiatr") + return e.updateWireGuardPeer(addr, presharedKey) + } + + // prevent to run new update while cancel the previous update + e.waitForCloseTheDelayedUpdate() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + e.updateWg.Add(1) + go e.scheduleDelayedUpdate(ctx, addr, presharedKey) + + e.log.Debugf("configure up WireGuard and wait for handshake") + return e.updateWireGuardPeer(nil, presharedKey) +} + +func (e *EndpointUpdater) RemoveWgPeer() error { + e.mu.Lock() + defer e.mu.Unlock() + + e.waitForCloseTheDelayedUpdate() + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() { + if e.cancelFunc == nil { + return + } + + e.cancelFunc() + e.cancelFunc = nil + e.updateWg.Wait() +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *EndpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr, presharedKey *wgtypes.Key) { + defer e.updateWg.Done() + t := time.NewTimer(fallbackDelay) + defer t.Stop() + + select { + case <-ctx.Done(): + return + case <-t.C: + e.mu.Lock() + if err := e.updateWireGuardPeer(addr, presharedKey); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + e.mu.Unlock() + } +} + +func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKey *wgtypes.Key) error { + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + presharedKey, + ) +}