diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index d781ebd77..d916ebad4 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -56,6 +56,13 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu) } + // Native firewall handles packet filtering, but the userspace WireGuard bind + // needs a device filter for DNS interception hooks. Install a minimal + // hooks-only filter that passes all traffic through to the kernel firewall. + if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil { + log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err) + } + return fm, nil } diff --git a/client/firewall/uspfilter/common/hooks.go b/client/firewall/uspfilter/common/hooks.go new file mode 100644 index 000000000..dadd800dd --- /dev/null +++ b/client/firewall/uspfilter/common/hooks.go @@ -0,0 +1,37 @@ +package common + +import ( + "net/netip" + "sync/atomic" +) + +// PacketHook stores a registered hook for a specific IP:port. +type PacketHook struct { + IP netip.Addr + Port uint16 + Fn func([]byte) bool +} + +// HookMatches checks if a packet's destination matches the hook and invokes it. +func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool { + if h == nil { + return false + } + if h.IP == dstIP && h.Port == dport { + return h.Fn(packetData) + } + return false +} + +// SetHook atomically stores a hook, handling nil removal. +func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) { + if hook == nil { + ptr.Store(nil) + return + } + ptr.Store(&PacketHook{ + IP: ip, + Port: dPort, + Fn: hook, + }) +} diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index cb9e1bb0a..24b3d0167 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -142,15 +142,8 @@ type Manager struct { mssClampEnabled bool // Only one hook per protocol is supported. Outbound direction only. - udpHookOut atomic.Pointer[packetHook] - tcpHookOut atomic.Pointer[packetHook] -} - -// packetHook stores a registered hook for a specific IP:port. -type packetHook struct { - ip netip.Addr - port uint16 - fn func([]byte) bool + udpHookOut atomic.Pointer[common.PacketHook] + tcpHookOut atomic.Pointer[common.PacketHook] } // decoder for packages @@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt } func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { - return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData) + return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData) } func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool { - return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData) -} - -func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool { - if h == nil { - return false - } - if h.ip == dstIP && h.port == dport { - return h.fn(packetData) - } - return false + return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData) } // filterInbound implements filtering logic for incoming packets. @@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot // SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove. func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { - if hook == nil { - m.udpHookOut.Store(nil) - return - } - m.udpHookOut.Store(&packetHook{ - ip: ip, - port: dPort, - fn: hook, - }) + common.SetHook(&m.udpHookOut, ip, dPort, hook) } // SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove. func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) { - if hook == nil { - m.tcpHookOut.Store(nil) - return - } - m.tcpHookOut.Store(&packetHook{ - ip: ip, - port: dPort, - fn: hook, - }) + common.SetHook(&m.tcpHookOut, ip, dPort, hook) } // SetLogLevel sets the log level for the firewall manager diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 5f0f9f860..39e8efa2c 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) { h := manager.udpHookOut.Load() require.NotNil(t, h) - assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip) - assert.Equal(t, uint16(8000), h.port) - assert.True(t, h.fn(nil)) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP) + assert.Equal(t, uint16(8000), h.Port) + assert.True(t, h.Fn(nil)) assert.True(t, called) manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil) @@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) { h := manager.tcpHookOut.Load() require.NotNil(t, h) - assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip) - assert.Equal(t, uint16(53), h.port) - assert.True(t, h.fn(nil)) + assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP) + assert.Equal(t, uint16(53), h.Port) + assert.True(t, h.Fn(nil)) assert.True(t, called) manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil) diff --git a/client/firewall/uspfilter/hooks_filter.go b/client/firewall/uspfilter/hooks_filter.go new file mode 100644 index 000000000..8d3cc0f5c --- /dev/null +++ b/client/firewall/uspfilter/hooks_filter.go @@ -0,0 +1,90 @@ +package uspfilter + +import ( + "encoding/binary" + "net/netip" + "sync/atomic" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + "github.com/netbirdio/netbird/client/iface/device" +) + +const ( + ipv4HeaderMinLen = 20 + ipv4ProtoOffset = 9 + ipv4FlagsOffset = 6 + ipv4DstOffset = 16 + ipProtoUDP = 17 + ipProtoTCP = 6 + ipv4FragOffMask = 0x1fff + // dstPortOffset is the offset of the destination port within a UDP or TCP header. + dstPortOffset = 2 +) + +// HooksFilter is a minimal packet filter that only handles outbound DNS hooks. +// It is installed on the WireGuard interface when the userspace bind is active +// but a full firewall filter (Manager) is not needed because a native kernel +// firewall (nftables/iptables) handles packet filtering. +type HooksFilter struct { + udpHook atomic.Pointer[common.PacketHook] + tcpHook atomic.Pointer[common.PacketHook] +} + +var _ device.PacketFilter = (*HooksFilter)(nil) + +// FilterOutbound checks outbound packets for DNS hook matches. +// Only IPv4 packets matching the registered hook IP:port are intercepted. +// IPv6 and non-IP packets pass through unconditionally. +func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool { + if len(packetData) < ipv4HeaderMinLen { + return false + } + + // Only process IPv4 packets, let everything else pass through. + if packetData[0]>>4 != 4 { + return false + } + + ihl := int(packetData[0]&0x0f) * 4 + if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 { + return false + } + + // Skip non-first fragments: they don't carry L4 headers. + flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2]) + if flagsAndOffset&ipv4FragOffMask != 0 { + return false + } + + dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4]) + if !ok { + return false + } + + proto := packetData[ipv4ProtoOffset] + dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2]) + + switch proto { + case ipProtoUDP: + return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData) + case ipProtoTCP: + return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData) + default: + return false + } +} + +// FilterInbound allows all inbound packets (native firewall handles filtering). +func (f *HooksFilter) FilterInbound([]byte, int) bool { + return false +} + +// SetUDPPacketHook registers the UDP packet hook. +func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) { + common.SetHook(&f.udpHook, ip, dPort, hook) +} + +// SetTCPPacketHook registers the TCP packet hook. +func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) { + common.SetHook(&f.tcpHook, ip, dPort, hook) +}