diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 5458519fa..1b1a8ce1c 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -8,8 +8,6 @@ import ( "net" "sync" - "github.com/google/gopacket" - "github.com/google/gopacket/layers" "github.com/hashicorp/go-multierror" "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" @@ -26,16 +24,6 @@ const ( loopbackAddr = "127.0.0.1" ) -var ( - localHostNetIPv4 = net.ParseIP("127.0.0.1") - localHostNetIPv6 = net.ParseIP("::1") - - serializeOpts = gopacket.SerializeOptions{ - ComputeChecksums: true, - FixLengths: true, - } -) - // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int @@ -253,63 +241,3 @@ generatePort: } return p.lastUsedPort, nil } - -func (p *WGEBPFProxy) sendPkg(data []byte, endpointAddr *net.UDPAddr) error { - - var ipH gopacket.SerializableLayer - var networkLayer gopacket.NetworkLayer - var dstIP net.IP - var rawConn net.PacketConn - - if endpointAddr.IP.To4() != nil { - // IPv4 path - ipv4 := &layers.IPv4{ - DstIP: localHostNetIPv4, - SrcIP: endpointAddr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - } - ipH = ipv4 - networkLayer = ipv4 - dstIP = localHostNetIPv4 - rawConn = p.rawConnIPv4 - } else { - // IPv6 path - if p.rawConnIPv6 == nil { - return fmt.Errorf("IPv6 raw socket not available") - } - ipv6 := &layers.IPv6{ - DstIP: localHostNetIPv6, - SrcIP: endpointAddr.IP, - Version: 6, - HopLimit: 64, - NextHeader: layers.IPProtocolUDP, - } - ipH = ipv6 - networkLayer = ipv6 - dstIP = localHostNetIPv6 - rawConn = p.rawConnIPv6 - } - - udpH := &layers.UDP{ - SrcPort: layers.UDPPort(endpointAddr.Port), - DstPort: layers.UDPPort(p.localWGListenPort), - } - - if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { - return fmt.Errorf("set network layer for checksum: %w", err) - } - - layerBuffer := gopacket.NewSerializeBuffer() - payload := gopacket.Payload(data) - - if err := gopacket.SerializeLayers(layerBuffer, serializeOpts, ipH, udpH, payload); err != nil { - return fmt.Errorf("serialize layers: %w", err) - } - - if _, err := rawConn.WriteTo(layerBuffer.Bytes(), &net.IPAddr{IP: dstIP}); err != nil { - return fmt.Errorf("write to raw conn: %w", err) - } - return nil -} diff --git a/client/iface/wgproxy/ebpf/wrapper.go b/client/iface/wgproxy/ebpf/wrapper.go index 5b98be7b4..6e80945c4 100644 --- a/client/iface/wgproxy/ebpf/wrapper.go +++ b/client/iface/wgproxy/ebpf/wrapper.go @@ -10,12 +10,89 @@ import ( "net" "sync" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/iface/bufsize" "github.com/netbirdio/netbird/client/iface/wgproxy/listener" ) +var ( + errIPv6ConnNotAvailable = errors.New("IPv6 endpoint but rawConnIPv6 is not available") + errIPv4ConnNotAvailable = errors.New("IPv4 endpoint but rawConnIPv4 is not available") + + localHostNetIPv4 = net.ParseIP("127.0.0.1") + localHostNetIPv6 = net.ParseIP("::1") + + serializeOpts = gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } +) + +// PacketHeaders holds pre-created headers and buffers for efficient packet sending +type PacketHeaders struct { + ipH gopacket.SerializableLayer + udpH *layers.UDP + layerBuffer gopacket.SerializeBuffer + localHostAddr net.IP + isIPv4 bool +} + +func NewPacketHeaders(localWGListenPort int, endpoint *net.UDPAddr) (*PacketHeaders, error) { + var ipH gopacket.SerializableLayer + var networkLayer gopacket.NetworkLayer + var localHostAddr net.IP + var isIPv4 bool + + // Check if source address is IPv4 or IPv6 + if endpoint.IP.To4() != nil { + // IPv4 path + ipv4 := &layers.IPv4{ + DstIP: localHostNetIPv4, + SrcIP: endpoint.IP, + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolUDP, + } + ipH = ipv4 + networkLayer = ipv4 + localHostAddr = localHostNetIPv4 + isIPv4 = true + } else { + // IPv6 path + ipv6 := &layers.IPv6{ + DstIP: localHostNetIPv6, + SrcIP: endpoint.IP, + Version: 6, + HopLimit: 64, + NextHeader: layers.IPProtocolUDP, + } + ipH = ipv6 + networkLayer = ipv6 + localHostAddr = localHostNetIPv6 + isIPv4 = false + } + + udpH := &layers.UDP{ + SrcPort: layers.UDPPort(endpoint.Port), + DstPort: layers.UDPPort(localWGListenPort), + } + + if err := udpH.SetNetworkLayerForChecksum(networkLayer); err != nil { + return nil, fmt.Errorf("set network layer for checksum: %w", err) + } + + return &PacketHeaders{ + ipH: ipH, + udpH: udpH, + layerBuffer: gopacket.NewSerializeBuffer(), + localHostAddr: localHostAddr, + isIPv4: isIPv4, + }, nil +} + // ProxyWrapper help to keep the remoteConn instance for net.Conn.Close function call type ProxyWrapper struct { wgeBPFProxy *WGEBPFProxy @@ -24,8 +101,10 @@ type ProxyWrapper struct { ctx context.Context cancel context.CancelFunc - wgRelayedEndpointAddr *net.UDPAddr - wgEndpointCurrentUsedAddr *net.UDPAddr + wgRelayedEndpointAddr *net.UDPAddr + headers *PacketHeaders + headerCurrentUsed *PacketHeaders + rawConn net.PacketConn paused bool pausedCond *sync.Cond @@ -41,15 +120,32 @@ func NewProxyWrapper(proxy *WGEBPFProxy) *ProxyWrapper { closeListener: listener.NewCloseListener(), } } + func (p *ProxyWrapper) AddTurnConn(ctx context.Context, _ *net.UDPAddr, remoteConn net.Conn) error { addr, err := p.wgeBPFProxy.AddTurnConn(remoteConn) if err != nil { return fmt.Errorf("add turn conn: %w", err) } + + headers, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, addr) + if err != nil { + return fmt.Errorf("create packet sender: %w", err) + } + + // Check if required raw connection is available + if !headers.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil { + return errIPv6ConnNotAvailable + } + if headers.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil { + return errIPv4ConnNotAvailable + } + p.remoteConn = remoteConn p.ctx, p.cancel = context.WithCancel(ctx) p.wgRelayedEndpointAddr = addr - return err + p.headers = headers + p.rawConn = p.selectRawConn(headers) + return nil } func (p *ProxyWrapper) EndpointAddr() *net.UDPAddr { @@ -68,7 +164,8 @@ func (p *ProxyWrapper) Work() { p.pausedCond.L.Lock() p.paused = false - p.wgEndpointCurrentUsedAddr = p.wgRelayedEndpointAddr + p.headerCurrentUsed = p.headers + p.rawConn = p.selectRawConn(p.headerCurrentUsed) if !p.isStarted { p.isStarted = true @@ -95,10 +192,28 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) { log.Errorf("failed to start package redirection, endpoint is nil") return } + + header, err := NewPacketHeaders(p.wgeBPFProxy.localWGListenPort, endpoint) + if err != nil { + log.Errorf("failed to create packet headers: %s", err) + return + } + + // Check if required raw connection is available + if !header.isIPv4 && p.wgeBPFProxy.rawConnIPv6 == nil { + log.Error(errIPv6ConnNotAvailable) + return + } + if header.isIPv4 && p.wgeBPFProxy.rawConnIPv4 == nil { + log.Error(errIPv4ConnNotAvailable) + return + } + p.pausedCond.L.Lock() p.paused = false - p.wgEndpointCurrentUsedAddr = endpoint + p.headerCurrentUsed = header + p.rawConn = p.selectRawConn(header) p.pausedCond.Signal() p.pausedCond.L.Unlock() @@ -140,7 +255,7 @@ func (p *ProxyWrapper) proxyToLocal(ctx context.Context) { p.pausedCond.Wait() } - err = p.wgeBPFProxy.sendPkg(buf[:n], p.wgEndpointCurrentUsedAddr) + err = p.sendPkg(buf[:n], p.headerCurrentUsed) p.pausedCond.L.Unlock() if err != nil { @@ -166,3 +281,29 @@ func (p *ProxyWrapper) readFromRemote(ctx context.Context, buf []byte) (int, err } return n, nil } + +func (p *ProxyWrapper) sendPkg(data []byte, header *PacketHeaders) error { + defer func() { + if err := header.layerBuffer.Clear(); err != nil { + log.Errorf("failed to clear layer buffer: %s", err) + } + }() + + payload := gopacket.Payload(data) + + if err := gopacket.SerializeLayers(header.layerBuffer, serializeOpts, header.ipH, header.udpH, payload); err != nil { + return fmt.Errorf("serialize layers: %w", err) + } + + if _, err := p.rawConn.WriteTo(header.layerBuffer.Bytes(), &net.IPAddr{IP: header.localHostAddr}); err != nil { + return fmt.Errorf("write to raw conn: %w", err) + } + return nil +} + +func (p *ProxyWrapper) selectRawConn(header *PacketHeaders) net.PacketConn { + if header.isIPv4 { + return p.wgeBPFProxy.rawConnIPv4 + } + return p.wgeBPFProxy.rawConnIPv6 +}