From 05af39a69b510bb6094588eee1f6776fef165476 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Mon, 26 Jan 2026 14:03:32 +0100 Subject: [PATCH] [client] Add IPv6 support to UDP WireGuard proxy (#5169) * Add IPv6 support to UDP WireGuard proxy Add IPv6 packet header support in UDP raw socket proxy to handle both IPv4 and IPv6 source addresses. Refactor error handling in proxy bind implementations to validate endpoints before acquiring locks. --- client/iface/wgproxy/bind/proxy.go | 40 +-- client/iface/wgproxy/ebpf/proxy.go | 93 +++++- client/iface/wgproxy/ebpf/wrapper.go | 10 +- client/iface/wgproxy/rawsocket/rawsocket.go | 56 +++- client/iface/wgproxy/redirect_test.go | 353 ++++++++++++++++++++ client/iface/wgproxy/udp/proxy.go | 2 +- client/iface/wgproxy/udp/rawsocket.go | 78 +++-- 7 files changed, 565 insertions(+), 67 deletions(-) create mode 100644 client/iface/wgproxy/redirect_test.go diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9978cceee..9ac3ea6df 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -114,34 +114,21 @@ func (p *ProxyBind) Pause() { } func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + ep, err := addrToEndpoint(endpoint) + if err != nil { + log.Errorf("failed to start package redirection: %v", err) + return + } + p.pausedCond.L.Lock() p.paused = false - ep, err := addrToEndpoint(endpoint) - if err != nil { - log.Errorf("failed to convert endpoint address: %v", err) - } else { - p.wgCurrentUsed = ep - } + p.wgCurrentUsed = ep p.pausedCond.Signal() p.pausedCond.L.Unlock() } -func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) { - if addr == nil { - return nil, errors.New("nil address") - } - - ip, ok := netip.AddrFromSlice(addr.IP) - if !ok { - return nil, fmt.Errorf("convert %s to netip.Addr", addr) - } - - addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)) - return &bind.Endpoint{AddrPort: addrPort}, nil -} - func (p *ProxyBind) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") @@ -225,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil } + +func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) { + if addr == nil { + return nil, fmt.Errorf("invalid address") + } + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok { + return nil, fmt.Errorf("convert %s to netip.Addr", addr) + } + + addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort}, nil +} diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 858143091..0c1c886d7 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -27,7 +27,13 @@ const ( ) var ( - localHostNetIP = net.ParseIP("127.0.0.1") + localHostNetIPv4 = net.ParseIP("127.0.0.1") + localHostNetIPv6 = net.ParseIP("::1") + + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } ) // WGEBPFProxy definition for proxy with EBPF support @@ -40,7 +46,8 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex lastUsedPort uint16 - rawConn net.PacketConn + rawConnIPv4 net.PacketConn + rawConnIPv6 net.PacketConn conn transport.UDPConn ctx context.Context @@ -67,13 +74,28 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = rawsocket.PrepareSenderRawSocket() + // Prepare IPv4 raw socket (required) + p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4() if err != nil { return err } + // Prepare IPv6 raw socket (optional) + p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6() + if err != nil { + log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err) + } + err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort) if err != nil { + if closeErr := p.rawConnIPv4.Close(); closeErr != nil { + log.Warnf("failed to close IPv4 raw socket: %v", closeErr) + } + if p.rawConnIPv6 != nil { + if closeErr := p.rawConnIPv6.Close(); closeErr != nil { + log.Warnf("failed to close IPv6 raw socket: %v", closeErr) + } + } return err } @@ -135,8 +157,16 @@ func (p *WGEBPFProxy) Free() error { result = multierror.Append(result, err) } - if err := p.rawConn.Close(); err != nil { - result = multierror.Append(result, err) + if p.rawConnIPv4 != nil { + if err := p.rawConnIPv4.Close(); err != nil { + result = multierror.Append(result, err) + } + } + + if p.rawConnIPv6 != nil { + if err := p.rawConnIPv6.Close(); err != nil { + result = multierror.Append(result, err) + } } return nberrors.FormatErrorOrNil(result) } @@ -218,31 +248,60 @@ generatePort: } func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { - payload := gopacket.Payload(data) - ipH := &layers.IPv4{ - DstIP: localHostNetIP, - SrcIP: endpointAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + var dstIP net.IP + var rawConn net.PacketConn + + if endpointAddr.IP.To4() != nil { + // IPv4 path + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPv4, + SrcIP: endpointAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + dstIP = localHostNetIPv4 + rawConn = p.rawConnIPv4 + } else { + // IPv6 path + if p.rawConnIPv6 == nil { + return fmt.Errorf("IPv6 raw socket not available") + } + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPv6, + SrcIP: endpointAddr.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 + dstIP = localHostNetIPv6 + rawConn = p.rawConnIPv6 } + udpH := &layers.UDP{ SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } - err := udpH.SetNetworkLayerForChecksum(ipH) - if err != nil { + if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { return fmt.Errorf("set network layer for checksum: %w", err) } layerBuffer := gopacket.NewSerializeBuffer() + payload := gopacket.Payload(data) - err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload) - if err != nil { + if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { + + if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index f1f05a7c9..5b98be7b4 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -41,7 +41,7 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { closeListener: listener.NewCloseListener(), } } -func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) @@ -91,12 +91,14 @@ func (p *ProxyWrapper) Pause() { } func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + if endpoint == nil || endpoint.IP == nil { + log.Errorf("failed to start package redirection, endpoint is nil") + return + } p.pausedCond.L.Lock() p.paused = false - if endpoint != nil && endpoint.IP != nil { - p.wgEndpointCurrentUsedAddr = endpoint - } + p.wgEndpointCurrentUsedAddr = endpoint p.pausedCond.Signal() p.pausedCond.L.Unlock() diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go index a11ac46d5..bc785b43a 100644 --- a/client/iface/wgproxy/rawsocket/rawsocket.go +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -8,43 +8,87 @@ import ( "os" "syscall" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + nbnet "github.com/netbirdio/netbird/client/net" ) -func PrepareSenderRawSocket() (net.PacketConn, error) { +// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets +func PrepareSenderRawSocketIPv4() (net.PacketConn, error) { + return prepareSenderRawSocket(syscall.AF_INET, true) +} + +// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets +func PrepareSenderRawSocketIPv6() (net.PacketConn, error) { + return prepareSenderRawSocket(syscall.AF_INET6, false) +} + +func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) { // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW) if err != nil { return nil, fmt.Errorf("creating raw socket failed: %w", err) } - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + // Set the header include option on the socket to tell the kernel that headers are included in the packet. + // For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers. + if isIPv4 { + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1) + if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + } else { + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1) + if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } + return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err) + } } // Bind the socket to the "lo" interface. err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("binding to lo interface failed: %w", err) } // Set the fwmark on the socket. err = nbnet.SetSocketOpt(fd) if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("setting fwmark failed: %w", err) } // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("converting fd to file failed") } packetConn, err := net.FilePacketConn(file) if err != nil { + if closeErr := file.Close(); closeErr != nil { + log.Warnf("failed to close file: %v", closeErr) + } return nil, fmt.Errorf("converting file to packet conn failed: %w", err) } + // Close the original file to release the FD (net.FilePacketConn duplicates it) + if closeErr := file.Close(); closeErr != nil { + log.Warnf("failed to close file after creating packet conn: %v", closeErr) + } + return packetConn, nil } diff --git a/client/iface/wgproxy/redirect_test.go b/client/iface/wgproxy/redirect_test.go new file mode 100644 index 000000000..b52eead25 --- /dev/null +++ b/client/iface/wgproxy/redirect_test.go @@ -0,0 +1,353 @@ +//go:build linux && !android + +package wgproxy + +import ( + "context" + "net" + "testing" + "time" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs +// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore +func compareUDPAddr(addr1, addr2 net.Addr) bool { + udpAddr1, ok1 := addr1.(*net.UDPAddr) + udpAddr2, ok2 := addr2.(*net.UDPAddr) + + if !ok1 || !ok2 { + return addr1.String() == addr2.String() + } + + // Compare IP and Port, ignoring zone + return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port +} + +// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses +func TestRedirectAs_eBPF_IPv4(t *testing.T) { + wgPort := 51850 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("192.168.0.56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses +func TestRedirectAs_eBPF_IPv6(t *testing.T) { + wgPort := 51851 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("fe80::56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses +func TestRedirectAs_UDP_IPv4(t *testing.T) { + wgPort := 51852 + proxy := udp.NewWGUDPProxy(wgPort, 1280) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("192.168.0.56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses +func TestRedirectAs_UDP_IPv6(t *testing.T) { + wgPort := 51853 + proxy := udp.NewWGUDPProxy(wgPort, 1280) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("fe80::56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// testRedirectAs is a helper function that tests the RedirectAs functionality +// It verifies that: +// 1. Initial traffic from relay connection works +// 2. After calling RedirectAs, packets appear to come from the p2p endpoint +// 3. Multiple packets are correctly redirected with the new source address +func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) { + t.Helper() + + ctx := context.Background() + + // Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types + // In reality, WireGuard binds to a port and receives from both IPv4 and IPv6 + wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create IPv4 WireGuard listener: %v", err) + } + defer wgListener4.Close() + + wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{ + IP: net.ParseIP("::1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create IPv6 WireGuard listener: %v", err) + } + defer wgListener6.Close() + + // Determine which listener to use based on the NetBird address IP version + // (this is where initial traffic will come from before RedirectAs is called) + var wgListener *net.UDPConn + if p2pEndpoint.IP.To4() == nil { + wgListener = wgListener6 + } else { + wgListener = wgListener4 + } + + // Create relay server and connection + relayServer, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, // Random port + }) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + relayConn, err := net.Dial("udp", relayServer.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to create relay connection: %v", err) + } + defer relayConn.Close() + + // Add TURN connection to proxy + if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil { + t.Fatalf("failed to add TURN connection: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("failed to close proxy connection: %v", err) + } + }() + + // Start the proxy + proxy.Work() + + // Phase 1: Test initial relay traffic + msgFromRelay := []byte("hello from relay") + if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write to relay server: %v", err) + } + + // Set read deadline to avoid hanging + if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + buf := make([]byte, 1024) + n, _, err := wgListener4.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read from WireGuard listener: %v", err) + } + + if n != len(msgFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n) + } + + if string(buf[:n]) != string(msgFromRelay) { + t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n]) + } + + // Phase 2: Redirect to p2p endpoint + proxy.RedirectAs(p2pEndpoint) + + // Give the proxy a moment to process the redirect + time.Sleep(100 * time.Millisecond) + + // Phase 3: Test redirected traffic + redirectedMessages := [][]byte{ + []byte("redirected message 1"), + []byte("redirected message 2"), + []byte("redirected message 3"), + } + + for i, msg := range redirectedMessages { + if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write redirected message %d: %v", i+1, err) + } + + if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + n, srcAddr, err := wgListener.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read redirected message %d: %v", i+1, err) + } + + // Verify message content + if string(buf[:n]) != string(msg) { + t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n]) + } + + // Verify source address matches p2p endpoint (this is the key test) + // Use compareUDPAddr to ignore IPv6 zone IDs + if !compareUDPAddr(srcAddr, p2pEndpoint) { + t.Errorf("message %d: expected source address %s, got %s", + i+1, p2pEndpoint.String(), srcAddr.String()) + } + } +} + +// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints +func TestRedirectAs_Multiple_Switches(t *testing.T) { + wgPort := 51856 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + ctx := context.Background() + + // Create WireGuard listener + wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create WireGuard listener: %v", err) + } + defer wgListener.Close() + + // Create relay server and connection + relayServer, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + relayConn, err := net.Dial("udp", relayServer.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to create relay connection: %v", err) + } + defer relayConn.Close() + + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil { + t.Fatalf("failed to add TURN connection: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("failed to close proxy connection: %v", err) + } + }() + + proxy.Work() + + // Test switching between multiple endpoints - using addresses in local subnet + endpoints := []*net.UDPAddr{ + {IP: net.ParseIP("192.168.0.100"), Port: 51820}, + {IP: net.ParseIP("192.168.0.101"), Port: 51821}, + {IP: net.ParseIP("192.168.0.102"), Port: 51822}, + } + + for i, endpoint := range endpoints { + proxy.RedirectAs(endpoint) + time.Sleep(100 * time.Millisecond) + + msg := []byte("test message") + if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write message for endpoint %d: %v", i, err) + } + + buf := make([]byte, 1024) + if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + n, srcAddr, err := wgListener.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read message for endpoint %d: %v", i, err) + } + + if string(buf[:n]) != string(msg) { + t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n]) + } + + if !compareUDPAddr(srcAddr, endpoint) { + t.Errorf("endpoint %d: expected source %s, got %s", + i, endpoint.String(), srcAddr.String()) + } + } +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 4ef2f19c4..6069d1960 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { // the connection is complete, an error is returned. Once successfully // connected, any expiration of the context will not affect the // connection. -func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { +func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { dialer := net.Dialer{} localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go index fdc911463..cc099d9df 100644 --- a/client/iface/wgproxy/udp/rawsocket.go +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -19,37 +19,56 @@ var ( FixLengths: true, } - localHostNetIPAddr = &net.IPAddr{ + localHostNetIPAddrV4 = &net.IPAddr{ IP: net.ParseIP("127.0.0.1"), } + localHostNetIPAddrV6 = &net.IPAddr{ + IP: net.ParseIP("::1"), + } ) type SrcFaker struct { srcAddr *net.UDPAddr - rawSocket net.PacketConn - ipH gopacket.SerializableLayer - udpH gopacket.SerializableLayer - layerBuffer gopacket.SerializeBuffer + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer + localHostAddr *net.IPAddr } func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { - rawSocket, err := rawsocket.PrepareSenderRawSocket() + // Create only the raw socket for the address family we need + var rawSocket net.PacketConn + var err error + var localHostAddr *net.IPAddr + + if srcAddr.IP.To4() != nil { + rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4() + localHostAddr = localHostNetIPAddrV4 + } else { + rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6() + localHostAddr = localHostNetIPAddrV6 + } if err != nil { return nil, err } ipH, udpH, err := prepareHeaders(dstPort, srcAddr) if err != nil { + if closeErr := rawSocket.Close(); closeErr != nil { + log.Warnf("failed to close raw socket: %v", closeErr) + } return nil, err } f := &SrcFaker{ - srcAddr: srcAddr, - rawSocket: rawSocket, - ipH: ipH, - udpH: udpH, - layerBuffer: gopacket.NewSerializeBuffer(), + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + localHostAddr: localHostAddr, } return f, nil @@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) { if err != nil { return 0, fmt.Errorf("serialize layers: %w", err) } - n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr) if err != nil { return 0, fmt.Errorf("write to raw conn: %w", err) } @@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) { } func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { - ipH := &layers.IPv4{ - DstIP: net.ParseIP("127.0.0.1"), - SrcIP: srcAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + + // Check if source IP is IPv4 or IPv6 + if srcAddr.IP.To4() != nil { + // IPv4 + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPAddrV4.IP, + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + } else { + // IPv6 + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPAddrV6.IP, + SrcIP: srcAddr.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 } + udpH := &layers.UDP{ SrcPort: layers.UDPPort(srcAddr.Port), DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port } - err := udpH.SetNetworkLayerForChecksum(ipH) + err := udpH.SetNetworkLayerForChecksum(networkLayer) if err != nil { return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) }