From d4f7df271aa1ddd330e872aab70a7f0451e31405 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 27 Jan 2026 18:04:23 +0800 Subject: [PATCH] [cllient] Don't track ebpf traffic in conntrack (#5166) --- client/firewall/iptables/manager_linux.go | 127 ++++++++++++++ client/firewall/manager/firewall.go | 4 + client/firewall/nftables/manager_linux.go | 191 +++++++++++++++++++++- client/firewall/uspfilter/filter.go | 8 + client/iface/iface.go | 7 + client/iface/wgproxy/ebpf/proxy.go | 15 +- client/iface/wgproxy/factory_kernel.go | 8 + client/iface/wgproxy/factory_usp.go | 5 + client/internal/engine.go | 22 +++ client/internal/engine_test.go | 8 + client/internal/iface_common.go | 1 + 11 files changed, 389 insertions(+), 7 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 2563a9052..716385705 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,6 +83,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return fmt.Errorf("acl manager init: %w", err) } + if err := m.initNoTrackChain(); err != nil { + return fmt.Errorf("init notrack chain: %w", err) + } + // persist early to ensure cleanup of chains go func() { if err := stateManager.PersistState(context.Background()); err != nil { @@ -177,6 +181,10 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { var merr *multierror.Error + if err := m.cleanupNoTrackChain(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("cleanup notrack chain: %w", err)) + } + if err := m.aclMgr.Reset(); err != nil { merr = multierror.Append(merr, fmt.Errorf("reset acl manager: %w", err)) } @@ -277,6 +285,125 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +const ( + chainNameRaw = "NETBIRD-RAW" + chainOUTPUT = "OUTPUT" + tableRaw = "raw" +) + +// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. +// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which +// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark). +// +// Traffic flows that need NOTRACK: +// +// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite) +// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort +// Matched by: sport=wgPort +// +// 2. Egress: Proxy -> WireGuard (via raw socket) +// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 3. Ingress: Packets to WireGuard +// dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 4. Ingress: Packets to proxy (after eBPF rewrite) +// dst=127.0.0.1:proxyPort +// Matched by: dport=proxyPort +// +// Rules are cleaned up when the firewall manager is closed. +func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + wgPortStr := fmt.Sprintf("%d", wgPort) + proxyPortStr := fmt.Sprintf("%d", proxyPort) + + // Egress rules: match outgoing loopback UDP packets + outputRuleSport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--sport", wgPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleSport...); err != nil { + return fmt.Errorf("add output sport notrack rule: %w", err) + } + + outputRuleDport := []string{"-o", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, outputRuleDport...); err != nil { + return fmt.Errorf("add output dport notrack rule: %w", err) + } + + // Ingress rules: match incoming loopback UDP packets + preroutingRuleWg := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", wgPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleWg...); err != nil { + return fmt.Errorf("add prerouting wg notrack rule: %w", err) + } + + preroutingRuleProxy := []string{"-i", "lo", "-s", "127.0.0.1", "-d", "127.0.0.1", "-p", "udp", "--dport", proxyPortStr, "-j", "NOTRACK"} + if err := m.ipv4Client.AppendUnique(tableRaw, chainNameRaw, preroutingRuleProxy...); err != nil { + return fmt.Errorf("add prerouting proxy notrack rule: %w", err) + } + + log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort) + return nil +} + +func (m *Manager) initNoTrackChain() error { + if err := m.cleanupNoTrackChain(); err != nil { + log.Debugf("cleanup notrack chain: %v", err) + } + + if err := m.ipv4Client.NewChain(tableRaw, chainNameRaw); err != nil { + return fmt.Errorf("create chain: %w", err) + } + + jumpRule := []string{"-j", chainNameRaw} + + if err := m.ipv4Client.InsertUnique(tableRaw, chainOUTPUT, 1, jumpRule...); err != nil { + if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil { + log.Debugf("delete orphan chain: %v", delErr) + } + return fmt.Errorf("add output jump rule: %w", err) + } + + if err := m.ipv4Client.InsertUnique(tableRaw, chainPREROUTING, 1, jumpRule...); err != nil { + if delErr := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); delErr != nil { + log.Debugf("delete output jump rule: %v", delErr) + } + if delErr := m.ipv4Client.DeleteChain(tableRaw, chainNameRaw); delErr != nil { + log.Debugf("delete orphan chain: %v", delErr) + } + return fmt.Errorf("add prerouting jump rule: %w", err) + } + + return nil +} + +func (m *Manager) cleanupNoTrackChain() error { + exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw) + if err != nil { + return fmt.Errorf("check chain exists: %w", err) + } + if !exists { + return nil + } + + jumpRule := []string{"-j", chainNameRaw} + + if err := m.ipv4Client.DeleteIfExists(tableRaw, chainOUTPUT, jumpRule...); err != nil { + return fmt.Errorf("remove output jump rule: %w", err) + } + + if err := m.ipv4Client.DeleteIfExists(tableRaw, chainPREROUTING, jumpRule...); err != nil { + return fmt.Errorf("remove prerouting jump rule: %w", err) + } + + if err := m.ipv4Client.ClearAndDeleteChain(tableRaw, chainNameRaw); err != nil { + return fmt.Errorf("clear and delete chain: %w", err) + } + + return nil +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 72e6a5c68..3511a5463 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -168,6 +168,10 @@ type Manager interface { // RemoveInboundDNAT removes inbound DNAT rule RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic. + // This prevents conntrack from interfering with WireGuard proxy communication. + SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index bd19f1067..acf482f86 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -12,6 +12,7 @@ import ( "github.com/google/nftables/binaryutil" "github.com/google/nftables/expr" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -48,8 +49,10 @@ type Manager struct { rConn *nftables.Conn wgIface iFaceMapper - router *router - aclManager *AclManager + router *router + aclManager *AclManager + notrackOutputChain *nftables.Chain + notrackPreroutingChain *nftables.Chain } // Create nftables firewall manager @@ -91,6 +94,10 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { return fmt.Errorf("acl manager init: %w", err) } + if err := m.initNoTrackChains(workTable); err != nil { + return fmt.Errorf("init notrack chains: %w", err) + } + stateManager.RegisterState(&ShutdownState{}) // We only need to record minimal interface state for potential recreation. @@ -288,7 +295,15 @@ func (m *Manager) Flush() error { m.mutex.Lock() defer m.mutex.Unlock() - return m.aclManager.Flush() + if err := m.aclManager.Flush(); err != nil { + return err + } + + if err := m.refreshNoTrackChains(); err != nil { + log.Errorf("failed to refresh notrack chains: %v", err) + } + + return nil } // AddDNATRule adds a DNAT rule @@ -331,6 +346,176 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) } +const ( + chainNameRawOutput = "netbird-raw-out" + chainNameRawPrerouting = "netbird-raw-pre" +) + +// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. +// This prevents conntrack from tracking WireGuard proxy traffic on loopback, which +// can interfere with MASQUERADE rules (e.g., from container runtimes like Podman/netavark). +// +// Traffic flows that need NOTRACK: +// +// 1. Egress: WireGuard -> fake endpoint (before eBPF rewrite) +// src=127.0.0.1:wgPort -> dst=127.0.0.1:fakePort +// Matched by: sport=wgPort +// +// 2. Egress: Proxy -> WireGuard (via raw socket) +// src=127.0.0.1:fakePort -> dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 3. Ingress: Packets to WireGuard +// dst=127.0.0.1:wgPort +// Matched by: dport=wgPort +// +// 4. Ingress: Packets to proxy (after eBPF rewrite) +// dst=127.0.0.1:proxyPort +// Matched by: dport=proxyPort +// +// Rules are cleaned up when the firewall manager is closed. +func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + if m.notrackOutputChain == nil || m.notrackPreroutingChain == nil { + return fmt.Errorf("notrack chains not initialized") + } + + proxyPortBytes := binaryutil.BigEndian.PutUint16(proxyPort) + wgPortBytes := binaryutil.BigEndian.PutUint16(wgPort) + loopback := []byte{127, 0, 0, 1} + + // Egress rules: match outgoing loopback UDP packets + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackOutputChain.Table, + Chain: m.notrackOutputChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 0, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // sport=wgPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackOutputChain.Table, + Chain: m.notrackOutputChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + + // Ingress rules: match incoming loopback UDP packets + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackPreroutingChain.Table, + Chain: m.notrackPreroutingChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: wgPortBytes}, // dport=wgPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + m.rConn.AddRule(&nftables.Rule{ + Table: m.notrackPreroutingChain.Table, + Chain: m.notrackPreroutingChain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname("lo")}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4}, // saddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: 16, Len: 4}, // daddr + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: loopback}, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}}, + &expr.Payload{DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2}, + &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: proxyPortBytes}, // dport=proxyPort + &expr.Counter{}, + &expr.Notrack{}, + }, + }) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush notrack rules: %w", err) + } + + log.Debugf("set up ebpf proxy notrack rules for ports %d,%d", proxyPort, wgPort) + return nil +} + +func (m *Manager) initNoTrackChains(table *nftables.Table) error { + m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{ + Name: chainNameRawOutput, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookOutput, + Priority: nftables.ChainPriorityRaw, + }) + + m.notrackPreroutingChain = m.rConn.AddChain(&nftables.Chain{ + Name: chainNameRawPrerouting, + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityRaw, + }) + + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush chain creation: %w", err) + } + + return nil +} + +func (m *Manager) refreshNoTrackChains() error { + chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return fmt.Errorf("list chains: %w", err) + } + + tableName := getTableName() + for _, c := range chains { + if c.Table.Name != tableName { + continue + } + switch c.Name { + case chainNameRawOutput: + m.notrackOutputChain = c + case chainNameRawPrerouting: + m.notrackPreroutingChain = c + } + } + + return nil +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 8caa1a0ad..aacc4ca1c 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -570,6 +570,14 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } +// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic. +func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error { + if m.nativeFirewall == nil { + return nil + } + return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort) +} + // UpdateSet updates the rule destinations associated with the given set // by merging the existing prefixes with the new ones, then deduplicating. func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { diff --git a/client/iface/iface.go b/client/iface/iface.go index 71fd433ad..e5623c979 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -50,6 +50,7 @@ func ValidateMTU(mtu uint16) error { type wgProxyFactory interface { GetProxy() wgproxy.Proxy + GetProxyPort() uint16 Free() error } @@ -80,6 +81,12 @@ func (w *WGIface) GetProxy() wgproxy.Proxy { return w.wgProxyFactory.GetProxy() } +// GetProxyPort returns the proxy port used by the WireGuard proxy. +// Returns 0 if no proxy port is used (e.g., for userspace WireGuard). +func (w *WGIface) GetProxyPort() uint16 { + return w.wgProxyFactory.GetProxyPort() +} + // GetBind returns the EndpointManager userspace bind mode. func (w *WGIface) GetBind() device.EndpointManager { w.mu.Lock() diff --git a/client/iface/wgproxy/ebpf/proxy.go b/client/iface/wgproxy/ebpf/proxy.go index 0c1c886d7..5458519fa 100644 --- a/client/iface/wgproxy/ebpf/proxy.go +++ b/client/iface/wgproxy/ebpf/proxy.go @@ -39,6 +39,7 @@ var ( // WGEBPFProxy definition for proxy with EBPF support type WGEBPFProxy struct { localWGListenPort int + proxyPort int mtu uint16 ebpfManager ebpfMgr.Manager @@ -69,10 +70,11 @@ func NewWGEBPFProxy(wgPort int, mtu uint16) *WGEBPFProxy { // Listen load ebpf program and listen the proxy func (p *WGEBPFProxy) Listen() error { pl := portLookup{} - wgPorxyPort, err := pl.searchFreePort() + proxyPort, err := pl.searchFreePort() if err != nil { return err } + p.proxyPort = proxyPort // Prepare IPv4 raw socket (required) p.rawConnIPv4, err = rawsocket.PrepareSenderRawSocketIPv4() @@ -86,7 +88,7 @@ func (p *WGEBPFProxy) Listen() error { log.Warnf("failed to prepare IPv6 raw socket, continuing with IPv4 only: %v", err) } - err = p.ebpfManager.LoadWgProxy(wgPorxyPort, p.localWGListenPort) + err = p.ebpfManager.LoadWgProxy(proxyPort, p.localWGListenPort) if err != nil { if closeErr := p.rawConnIPv4.Close(); closeErr != nil { log.Warnf("failed to close IPv4 raw socket: %v", closeErr) @@ -100,7 +102,7 @@ func (p *WGEBPFProxy) Listen() error { } addr := net.UDPAddr{ - Port: wgPorxyPort, + Port: proxyPort, IP: net.ParseIP(loopbackAddr), } @@ -116,7 +118,7 @@ func (p *WGEBPFProxy) Listen() error { p.conn = conn go p.proxyToRemote() - log.Infof("local wg proxy listening on: %d", wgPorxyPort) + log.Infof("local wg proxy listening on: %d", proxyPort) return nil } @@ -171,6 +173,11 @@ func (p *WGEBPFProxy) Free() error { return nberrors.FormatErrorOrNil(result) } +// GetProxyPort returns the proxy listening port. +func (p *WGEBPFProxy) GetProxyPort() uint16 { + return uint16(p.proxyPort) +} + // proxyToRemote read messages from local WireGuard interface and forward it to remote conn // From this go routine has only one instance. func (p *WGEBPFProxy) proxyToRemote() { diff --git a/client/iface/wgproxy/factory_kernel.go b/client/iface/wgproxy/factory_kernel.go index 2714c5774..7821df3de 100644 --- a/client/iface/wgproxy/factory_kernel.go +++ b/client/iface/wgproxy/factory_kernel.go @@ -54,6 +54,14 @@ func (w *KernelFactory) GetProxy() Proxy { return ebpf.NewProxyWrapper(w.ebpfProxy) } +// GetProxyPort returns the eBPF proxy port, or 0 if eBPF is not active. +func (w *KernelFactory) GetProxyPort() uint16 { + if w.ebpfProxy == nil { + return 0 + } + return w.ebpfProxy.GetProxyPort() +} + func (w *KernelFactory) Free() error { if w.ebpfProxy == nil { return nil diff --git a/client/iface/wgproxy/factory_usp.go b/client/iface/wgproxy/factory_usp.go index a1b1c34d7..bbd67e076 100644 --- a/client/iface/wgproxy/factory_usp.go +++ b/client/iface/wgproxy/factory_usp.go @@ -24,6 +24,11 @@ func (w *USPFactory) GetProxy() Proxy { return proxyBind.NewProxyBind(w.bind, w.mtu) } +// GetProxyPort returns 0 as userspace WireGuard doesn't use a separate proxy port. +func (w *USPFactory) GetProxyPort() uint16 { + return 0 +} + func (w *USPFactory) Free() error { return nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index a391ba22a..f0693e82c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -505,6 +505,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) return fmt.Errorf("up wg interface: %w", err) } + // Set up notrack rules immediately after proxy is listening to prevent + // conntrack entries from being created before the rules are in place + e.setupWGProxyNoTrack() + // Set the WireGuard interface for rosenpass after interface is up if e.rpManager != nil { e.rpManager.SetInterface(e.wgInterface) @@ -617,6 +621,23 @@ func (e *Engine) initFirewall() error { return nil } +// setupWGProxyNoTrack configures connection tracking exclusion for WireGuard proxy traffic. +// This prevents conntrack/MASQUERADE from affecting loopback traffic between WireGuard and the eBPF proxy. +func (e *Engine) setupWGProxyNoTrack() { + if e.firewall == nil { + return + } + + proxyPort := e.wgInterface.GetProxyPort() + if proxyPort == 0 { + return + } + + if err := e.firewall.SetupEBPFProxyNoTrack(proxyPort, uint16(e.config.WgPort)); err != nil { + log.Warnf("failed to setup ebpf proxy notrack: %v", err) + } +} + func (e *Engine) blockLanAccess() { if e.config.BlockInbound { // no need to set up extra deny rules if inbound is already blocked in general @@ -1644,6 +1665,7 @@ func (e *Engine) parseNATExternalIPMappings() []string { func (e *Engine) close() { log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) + if e.wgInterface != nil { if err := e.wgInterface.Close(); err != nil { log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index af9f27a71..012c8ad6e 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -107,6 +107,7 @@ type MockWGIface struct { GetStatsFunc func() (map[string]configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy + GetProxyPortFunc func() uint16 GetNetFunc func() *netstack.Net LastActivitiesFunc func() map[string]monotime.Time } @@ -203,6 +204,13 @@ func (m *MockWGIface) GetProxy() wgproxy.Proxy { return m.GetProxyFunc() } +func (m *MockWGIface) GetProxyPort() uint16 { + if m.GetProxyPortFunc != nil { + return m.GetProxyPortFunc() + } + return 0 +} + func (m *MockWGIface) GetNet() *netstack.Net { return m.GetNetFunc() } diff --git a/client/internal/iface_common.go b/client/internal/iface_common.go index f8a433a6e..39e9bacfa 100644 --- a/client/internal/iface_common.go +++ b/client/internal/iface_common.go @@ -28,6 +28,7 @@ type wgIfaceBase interface { Up() (*udpmux.UniversalUDPMuxDefault, error) UpdateAddr(newAddr string) error GetProxy() wgproxy.Proxy + GetProxyPort() uint16 UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error RemoveEndpointAddress(key string) error RemovePeer(peerKey string) error