diff --git a/client/iface/wgproxy/bind/proxy.go b/client/iface/wgproxy/bind/proxy.go index 9978cceee..9ac3ea6df 100644 --- a/client/iface/wgproxy/bind/proxy.go +++ b/client/iface/wgproxy/bind/proxy.go @@ -114,34 +114,21 @@ func (p *ProxyBind) Pause() { } func (p *ProxyBind) RedirectAs(endpoint *net.UDPAddr) { + ep, err := addrToEndpoint(endpoint) + if err != nil { + log.Errorf("failed to start package redirection: %v", err) + return + } + p.pausedCond.L.Lock() p.paused = false - ep, err := addrToEndpoint(endpoint) - if err != nil { - log.Errorf("failed to convert endpoint address: %v", err) - } else { - p.wgCurrentUsed = ep - } + p.wgCurrentUsed = ep p.pausedCond.Signal() p.pausedCond.L.Unlock() } -func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) { - if addr == nil { - return nil, errors.New("nil address") - } - - ip, ok := netip.AddrFromSlice(addr.IP) - if !ok { - return nil, fmt.Errorf("convert %s to netip.Addr", addr) - } - - addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)) - return &bind.Endpoint{AddrPort: addrPort}, nil -} - func (p *ProxyBind) CloseConn() error { if p.cancel == nil { return fmt.Errorf("proxy not started") @@ -225,3 +212,16 @@ func fakeAddress(peerAddress *net.UDPAddr) (*netip.AddrPort, error) { netipAddr := netip.AddrPortFrom(fakeIP, uint16(peerAddress.Port)) return &netipAddr, nil } + +func addrToEndpoint(addr *net.UDPAddr) (*bind.Endpoint, error) { + if addr == nil { + return nil, fmt.Errorf("invalid address") + } + ip, ok := netip.AddrFromSlice(addr.IP) + if !ok { + return nil, fmt.Errorf("convert %s to netip.Addr", addr) + } + + addrPort := netip.AddrPortFrom(ip.Unmap(), uint16(addr.Port)) + return &bind.Endpoint{AddrPort: addrPort}, nil +} diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 858143091..0c1c886d7 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -27,7 +27,13 @@ const ( ) var ( - localHostNetIP = net.ParseIP("127.0.0.1") + localHostNetIPv4 = net.ParseIP("127.0.0.1") + localHostNetIPv6 = net.ParseIP("::1") + + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } ) // WGEBPFProxy definition for proxy with EBPF support @@ -40,7 +46,8 @@ type WGEBPFProxy struct { turnConnMutex sync.Mutex lastUsedPort uint16 - rawConn net.PacketConn + rawConnIPv4 net.PacketConn + rawConnIPv6 net.PacketConn conn transport.UDPConn ctx context.Context @@ -67,13 +74,28 @@ func (p *WGEBPFProxy) Listen() error { return err } - p.rawConn, err = rawsocket.PrepareSenderRawSocket() + // Prepare IPv4 raw socket (required) + p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4() if err != nil { return err } + // Prepare IPv6 raw socket (optional) + p.rawConnIPv6, err = rawsocket.PrepareSenderRawSocketIPv6() + if err != nil { + log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err) + } + err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort) if err != nil { + if closeErr := p.rawConnIPv4.Close(); closeErr != nil { + log.Warnf("failed to close IPv4 raw socket: %v", closeErr) + } + if p.rawConnIPv6 != nil { + if closeErr := p.rawConnIPv6.Close(); closeErr != nil { + log.Warnf("failed to close IPv6 raw socket: %v", closeErr) + } + } return err } @@ -135,8 +157,16 @@ func (p *WGEBPFProxy) Free() error { result = multierror.Append(result, err) } - if err := p.rawConn.Close(); err != nil { - result = multierror.Append(result, err) + if p.rawConnIPv4 != nil { + if err := p.rawConnIPv4.Close(); err != nil { + result = multierror.Append(result, err) + } + } + + if p.rawConnIPv6 != nil { + if err := p.rawConnIPv6.Close(); err != nil { + result = multierror.Append(result, err) + } } return nberrors.FormatErrorOrNil(result) } @@ -218,31 +248,60 @@ generatePort: } func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { - payload := gopacket.Payload(data) - ipH := &layers.IPv4{ - DstIP: localHostNetIP, - SrcIP: endpointAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + var dstIP net.IP + var rawConn net.PacketConn + + if endpointAddr.IP.To4() != nil { + // IPv4 path + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPv4, + SrcIP: endpointAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + dstIP = localHostNetIPv4 + rawConn = p.rawConnIPv4 + } else { + // IPv6 path + if p.rawConnIPv6 == nil { + return fmt.Errorf("IPv6 raw socket not available") + } + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPv6, + SrcIP: endpointAddr.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 + dstIP = localHostNetIPv6 + rawConn = p.rawConnIPv6 } + udpH := &layers.UDP{ SrcPort: layers.UDPPort(endpointAddr.Port), DstPort: layers.UDPPort(p.localWGListenPort), } - err := udpH.SetNetworkLayerForChecksum(ipH) - if err != nil { + if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { return fmt.Errorf("set network layer for checksum: %w", err) } layerBuffer := gopacket.NewSerializeBuffer() + payload := gopacket.Payload(data) - err = gopacket.SerializeLayers(layerBuffer, gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}, ipH, udpH, payload) - if err != nil { + if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil { return fmt.Errorf("serialize layers: %w", err) } - if _, err = p.rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: localHostNetIP}); err != nil { + + if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil { return fmt.Errorf("write to raw conn: %w", err) } return nil diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index f1f05a7c9..5b98be7b4 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -41,7 +41,7 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { closeListener: listener.NewCloseListener(), } } -func (p *ProxyWrapper) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { +func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) @@ -91,12 +91,14 @@ func (p *ProxyWrapper) Pause() { } func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { + if endpoint == nil || endpoint.IP == nil { + log.Errorf("failed to start package redirection, endpoint is nil") + return + } p.pausedCond.L.Lock() p.paused = false - if endpoint != nil && endpoint.IP != nil { - p.wgEndpointCurrentUsedAddr = endpoint - } + p.wgEndpointCurrentUsedAddr = endpoint p.pausedCond.Signal() p.pausedCond.L.Unlock() diff --git a/client/iface/wgproxy/rawsocket/rawsocket.go b/client/iface/wgproxy/rawsocket/rawsocket.go index a11ac46d5..bc785b43a 100644 --- a/client/iface/wgproxy/rawsocket/rawsocket.go +++ b/client/iface/wgproxy/rawsocket/rawsocket.go @@ -8,43 +8,87 @@ import ( "os" "syscall" + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + nbnet "github.com/netbirdio/netbird/client/net" ) -func PrepareSenderRawSocket() (net.PacketConn, error) { +// PrepareSenderRawSocketIPv4 creates and configures a raw socket for sending IPv4 packets +func PrepareSenderRawSocketIPv4() (net.PacketConn, error) { + return prepareSenderRawSocket(syscall.AF_INET, true) +} + +// PrepareSenderRawSocketIPv6 creates and configures a raw socket for sending IPv6 packets +func PrepareSenderRawSocketIPv6() (net.PacketConn, error) { + return prepareSenderRawSocket(syscall.AF_INET6, false) +} + +func prepareSenderRawSocket(family int, isIPv4 bool) (net.PacketConn, error) { // Create a raw socket. - fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_RAW, syscall.IPPROTO_RAW) + fd, err := syscall.Socket(family, syscall.SOCK_RAW, syscall.IPPROTO_RAW) if err != nil { return nil, fmt.Errorf("creating raw socket failed: %w", err) } - // Set the IP_HDRINCL option on the socket to tell the kernel that headers are included in the packet. - err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, syscall.IP_HDRINCL, 1) - if err != nil { - return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + // Set the header include option on the socket to tell the kernel that headers are included in the packet. + // For IPv4, we need to set IP_HDRINCL. For IPv6, we need to set IPV6_HDRINCL to accept application-provided IPv6 headers. + if isIPv4 { + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IP, unix.IP_HDRINCL, 1) + if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } + return nil, fmt.Errorf("setting IP_HDRINCL failed: %w", err) + } + } else { + err = syscall.SetsockoptInt(fd, syscall.IPPROTO_IPV6, unix.IPV6_HDRINCL, 1) + if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } + return nil, fmt.Errorf("setting IPV6_HDRINCL failed: %w", err) + } } // Bind the socket to the "lo" interface. err = syscall.SetsockoptString(fd, syscall.SOL_SOCKET, syscall.SO_BINDTODEVICE, "lo") if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("binding to lo interface failed: %w", err) } // Set the fwmark on the socket. err = nbnet.SetSocketOpt(fd) if err != nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("setting fwmark failed: %w", err) } // Convert the file descriptor to a PacketConn. file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd)) if file == nil { + if closeErr := syscall.Close(fd); closeErr != nil { + log.Warnf("failed to close raw socket fd: %v", closeErr) + } return nil, fmt.Errorf("converting fd to file failed") } packetConn, err := net.FilePacketConn(file) if err != nil { + if closeErr := file.Close(); closeErr != nil { + log.Warnf("failed to close file: %v", closeErr) + } return nil, fmt.Errorf("converting file to packet conn failed: %w", err) } + // Close the original file to release the FD (net.FilePacketConn duplicates it) + if closeErr := file.Close(); closeErr != nil { + log.Warnf("failed to close file after creating packet conn: %v", closeErr) + } + return packetConn, nil } diff --git a/client/iface/wgproxy/redirect_test.go b/client/iface/wgproxy/redirect_test.go new file mode 100644 index 000000000..b52eead25 --- /dev/null +++ b/client/iface/wgproxy/redirect_test.go @@ -0,0 +1,353 @@ +//go:build linux && !android + +package wgproxy + +import ( + "context" + "net" + "testing" + "time" + + "github.com/netbirdio/netbird/client/iface/wgproxy/ebpf" + "github.com/netbirdio/netbird/client/iface/wgproxy/udp" +) + +// compareUDPAddr compares two UDP addresses, ignoring IPv6 zone IDs +// IPv6 link-local addresses include zone IDs (e.g., fe80::1%lo) which we should ignore +func compareUDPAddr(addr1, addr2 net.Addr) bool { + udpAddr1, ok1 := addr1.(*net.UDPAddr) + udpAddr2, ok2 := addr2.(*net.UDPAddr) + + if !ok1 || !ok2 { + return addr1.String() == addr2.String() + } + + // Compare IP and Port, ignoring zone + return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port +} + +// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses +func TestRedirectAs_eBPF_IPv4(t *testing.T) { + wgPort := 51850 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("192.168.0.56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses +func TestRedirectAs_eBPF_IPv6(t *testing.T) { + wgPort := 51851 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("fe80::56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses +func TestRedirectAs_UDP_IPv4(t *testing.T) { + wgPort := 51852 + proxy := udp.NewWGUDPProxy(wgPort, 1280) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("192.168.0.56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// TestRedirectAs_UDP_IPv6 tests RedirectAs with UDP proxy using IPv6 addresses +func TestRedirectAs_UDP_IPv6(t *testing.T) { + wgPort := 51853 + proxy := udp.NewWGUDPProxy(wgPort, 1280) + + // NetBird UDP address of the remote peer + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + p2pEndpoint := &net.UDPAddr{ + IP: net.ParseIP("fe80::56"), + Port: 51820, + } + + testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint) +} + +// testRedirectAs is a helper function that tests the RedirectAs functionality +// It verifies that: +// 1. Initial traffic from relay connection works +// 2. After calling RedirectAs, packets appear to come from the p2p endpoint +// 3. Multiple packets are correctly redirected with the new source address +func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *net.UDPAddr) { + t.Helper() + + ctx := context.Background() + + // Create WireGuard listeners on both IPv4 and IPv6 to support both P2P connection types + // In reality, WireGuard binds to a port and receives from both IPv4 and IPv6 + wgListener4, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create IPv4 WireGuard listener: %v", err) + } + defer wgListener4.Close() + + wgListener6, err := net.ListenUDP("udp6", &net.UDPAddr{ + IP: net.ParseIP("::1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create IPv6 WireGuard listener: %v", err) + } + defer wgListener6.Close() + + // Determine which listener to use based on the NetBird address IP version + // (this is where initial traffic will come from before RedirectAs is called) + var wgListener *net.UDPConn + if p2pEndpoint.IP.To4() == nil { + wgListener = wgListener6 + } else { + wgListener = wgListener4 + } + + // Create relay server and connection + relayServer, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, // Random port + }) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + relayConn, err := net.Dial("udp", relayServer.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to create relay connection: %v", err) + } + defer relayConn.Close() + + // Add TURN connection to proxy + if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil { + t.Fatalf("failed to add TURN connection: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("failed to close proxy connection: %v", err) + } + }() + + // Start the proxy + proxy.Work() + + // Phase 1: Test initial relay traffic + msgFromRelay := []byte("hello from relay") + if _, err := relayServer.WriteTo(msgFromRelay, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write to relay server: %v", err) + } + + // Set read deadline to avoid hanging + if err := wgListener4.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + buf := make([]byte, 1024) + n, _, err := wgListener4.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read from WireGuard listener: %v", err) + } + + if n != len(msgFromRelay) { + t.Errorf("expected %d bytes, got %d", len(msgFromRelay), n) + } + + if string(buf[:n]) != string(msgFromRelay) { + t.Errorf("expected message %q, got %q", msgFromRelay, buf[:n]) + } + + // Phase 2: Redirect to p2p endpoint + proxy.RedirectAs(p2pEndpoint) + + // Give the proxy a moment to process the redirect + time.Sleep(100 * time.Millisecond) + + // Phase 3: Test redirected traffic + redirectedMessages := [][]byte{ + []byte("redirected message 1"), + []byte("redirected message 2"), + []byte("redirected message 3"), + } + + for i, msg := range redirectedMessages { + if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write redirected message %d: %v", i+1, err) + } + + if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + n, srcAddr, err := wgListener.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read redirected message %d: %v", i+1, err) + } + + // Verify message content + if string(buf[:n]) != string(msg) { + t.Errorf("message %d: expected %q, got %q", i+1, msg, buf[:n]) + } + + // Verify source address matches p2p endpoint (this is the key test) + // Use compareUDPAddr to ignore IPv6 zone IDs + if !compareUDPAddr(srcAddr, p2pEndpoint) { + t.Errorf("message %d: expected source address %s, got %s", + i+1, p2pEndpoint.String(), srcAddr.String()) + } + } +} + +// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints +func TestRedirectAs_Multiple_Switches(t *testing.T) { + wgPort := 51856 + ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280) + if err := ebpfProxy.Listen(); err != nil { + t.Fatalf("failed to initialize ebpf proxy: %v", err) + } + defer func() { + if err := ebpfProxy.Free(); err != nil { + t.Errorf("failed to free ebpf proxy: %v", err) + } + }() + + proxy := ebpf.NewProxyWrapper(ebpfProxy) + + ctx := context.Background() + + // Create WireGuard listener + wgListener, err := net.ListenUDP("udp4", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: wgPort, + }) + if err != nil { + t.Fatalf("failed to create WireGuard listener: %v", err) + } + defer wgListener.Close() + + // Create relay server and connection + relayServer, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 0, + }) + if err != nil { + t.Fatalf("failed to create relay server: %v", err) + } + defer relayServer.Close() + + relayConn, err := net.Dial("udp", relayServer.LocalAddr().String()) + if err != nil { + t.Fatalf("failed to create relay connection: %v", err) + } + defer relayConn.Close() + + nbAddr := &net.UDPAddr{ + IP: net.ParseIP("100.108.111.177"), + Port: 38746, + } + + if err := proxy.AddTurnConn(ctx, nbAddr, relayConn); err != nil { + t.Fatalf("failed to add TURN connection: %v", err) + } + defer func() { + if err := proxy.CloseConn(); err != nil { + t.Errorf("failed to close proxy connection: %v", err) + } + }() + + proxy.Work() + + // Test switching between multiple endpoints - using addresses in local subnet + endpoints := []*net.UDPAddr{ + {IP: net.ParseIP("192.168.0.100"), Port: 51820}, + {IP: net.ParseIP("192.168.0.101"), Port: 51821}, + {IP: net.ParseIP("192.168.0.102"), Port: 51822}, + } + + for i, endpoint := range endpoints { + proxy.RedirectAs(endpoint) + time.Sleep(100 * time.Millisecond) + + msg := []byte("test message") + if _, err := relayServer.WriteTo(msg, relayConn.LocalAddr()); err != nil { + t.Fatalf("failed to write message for endpoint %d: %v", i, err) + } + + buf := make([]byte, 1024) + if err := wgListener.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { + t.Fatalf("failed to set read deadline: %v", err) + } + + n, srcAddr, err := wgListener.ReadFrom(buf) + if err != nil { + t.Fatalf("failed to read message for endpoint %d: %v", i, err) + } + + if string(buf[:n]) != string(msg) { + t.Errorf("endpoint %d: expected message %q, got %q", i, msg, buf[:n]) + } + + if !compareUDPAddr(srcAddr, endpoint) { + t.Errorf("endpoint %d: expected source %s, got %s", + i, endpoint.String(), srcAddr.String()) + } + } +} diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index 4ef2f19c4..6069d1960 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -56,7 +56,7 @@ func NewWGUDPProxy(wgPort int, mtu uint16) *WGUDPProxy { // the connection is complete, an error is returned. Once successfully // connected, any expiration of the context will not affect the // connection. -func (p *WGUDPProxy) AddTurnConn(ctx context.Context, endpoint *net.UDPAddr, remoteConn net.Conn) error { +func (p *WGUDPProxy) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { dialer := net.Dialer{} localConn, err := dialer.DialContext(ctx, "udp", fmt.Sprintf(":%d", p.localWGListenPort)) if err != nil { diff --git a/client/iface/wgproxy/udp/rawsocket.go b/client/iface/wgproxy/udp/rawsocket.go index fdc911463..cc099d9df 100644 --- a/client/iface/wgproxy/udp/rawsocket.go +++ b/client/iface/wgproxy/udp/rawsocket.go @@ -19,37 +19,56 @@ var ( FixLengths: true, } - localHostNetIPAddr = &net.IPAddr{ + localHostNetIPAddrV4 = &net.IPAddr{ IP: net.ParseIP("127.0.0.1"), } + localHostNetIPAddrV6 = &net.IPAddr{ + IP: net.ParseIP("::1"), + } ) type SrcFaker struct { srcAddr *net.UDPAddr - rawSocket net.PacketConn - ipH gopacket.SerializableLayer - udpH gopacket.SerializableLayer - layerBuffer gopacket.SerializeBuffer + rawSocket net.PacketConn + ipH gopacket.SerializableLayer + udpH gopacket.SerializableLayer + layerBuffer gopacket.SerializeBuffer + localHostAddr *net.IPAddr } func NewSrcFaker(dstPort int, srcAddr *net.UDPAddr) (*SrcFaker, error) { - rawSocket, err := rawsocket.PrepareSenderRawSocket() + // Create only the raw socket for the address family we need + var rawSocket net.PacketConn + var err error + var localHostAddr *net.IPAddr + + if srcAddr.IP.To4() != nil { + rawSocket, err = rawsocket.PrepareSenderRawSocketIPv4() + localHostAddr = localHostNetIPAddrV4 + } else { + rawSocket, err = rawsocket.PrepareSenderRawSocketIPv6() + localHostAddr = localHostNetIPAddrV6 + } if err != nil { return nil, err } ipH, udpH, err := prepareHeaders(dstPort, srcAddr) if err != nil { + if closeErr := rawSocket.Close(); closeErr != nil { + log.Warnf("failed to close raw socket: %v", closeErr) + } return nil, err } f := &SrcFaker{ - srcAddr: srcAddr, - rawSocket: rawSocket, - ipH: ipH, - udpH: udpH, - layerBuffer: gopacket.NewSerializeBuffer(), + srcAddr: srcAddr, + rawSocket: rawSocket, + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + localHostAddr: localHostAddr, } return f, nil @@ -72,7 +91,7 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) { if err != nil { return 0, fmt.Errorf("serialize layers: %w", err) } - n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), localHostNetIPAddr) + n, err := f.rawSocket.WriteTo(f.layerBuffer.Bytes(), f.localHostAddr) if err != nil { return 0, fmt.Errorf("write to raw conn: %w", err) } @@ -80,19 +99,40 @@ func (f *SrcFaker) SendPkg(data []byte) (int, error) { } func prepareHeaders(dstPort int, srcAddr *net.UDPAddr) (gopacket.SerializableLayer, gopacket.SerializableLayer, error) { - ipH := &layers.IPv4{ - DstIP: net.ParseIP("127.0.0.1"), - SrcIP: srcAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + + // Check if source IP is IPv4 or IPv6 + if srcAddr.IP.To4() != nil { + // IPv4 + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPAddrV4.IP, + SrcIP: srcAddr.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + } else { + // IPv6 + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPAddrV6.IP, + SrcIP: srcAddr.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 } + udpH := &layers.UDP{ SrcPort: layers.UDPPort(srcAddr.Port), DstPort: layers.UDPPort(dstPort), // dst is the localhost WireGuard port } - err := udpH.SetNetworkLayerForChecksum(ipH) + err := udpH.SetNetworkLayerForChecksum(networkLayer) if err != nil { return nil, nil, fmt.Errorf("set network layer for checksum: %w", err) }