mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
[client] Fix DNS resolution with userspace WireGuard and kernel firewall (#5873)
This commit is contained in:
37
client/firewall/uspfilter/common/hooks.go
Normal file
37
client/firewall/uspfilter/common/hooks.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// PacketHook stores a registered hook for a specific IP:port.
|
||||
type PacketHook struct {
|
||||
IP netip.Addr
|
||||
Port uint16
|
||||
Fn func([]byte) bool
|
||||
}
|
||||
|
||||
// HookMatches checks if a packet's destination matches the hook and invokes it.
|
||||
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
if h.IP == dstIP && h.Port == dport {
|
||||
return h.Fn(packetData)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetHook atomically stores a hook, handling nil removal.
|
||||
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||
if hook == nil {
|
||||
ptr.Store(nil)
|
||||
return
|
||||
}
|
||||
ptr.Store(&PacketHook{
|
||||
IP: ip,
|
||||
Port: dPort,
|
||||
Fn: hook,
|
||||
})
|
||||
}
|
||||
@@ -142,15 +142,8 @@ type Manager struct {
|
||||
mssClampEnabled bool
|
||||
|
||||
// Only one hook per protocol is supported. Outbound direction only.
|
||||
udpHookOut atomic.Pointer[packetHook]
|
||||
tcpHookOut atomic.Pointer[packetHook]
|
||||
}
|
||||
|
||||
// packetHook stores a registered hook for a specific IP:port.
|
||||
type packetHook struct {
|
||||
ip netip.Addr
|
||||
port uint16
|
||||
fn func([]byte) bool
|
||||
udpHookOut atomic.Pointer[common.PacketHook]
|
||||
tcpHookOut atomic.Pointer[common.PacketHook]
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -912,21 +905,11 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
||||
}
|
||||
|
||||
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
|
||||
if h == nil {
|
||||
return false
|
||||
}
|
||||
if h.ip == dstIP && h.port == dport {
|
||||
return h.fn(packetData)
|
||||
}
|
||||
return false
|
||||
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
|
||||
}
|
||||
|
||||
// filterInbound implements filtering logic for incoming packets.
|
||||
@@ -1337,28 +1320,12 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
|
||||
|
||||
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.udpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
m.udpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
common.SetHook(&m.udpHookOut, ip, dPort, hook)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
|
||||
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
|
||||
if hook == nil {
|
||||
m.tcpHookOut.Store(nil)
|
||||
return
|
||||
}
|
||||
m.tcpHookOut.Store(&packetHook{
|
||||
ip: ip,
|
||||
port: dPort,
|
||||
fn: hook,
|
||||
})
|
||||
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
|
||||
}
|
||||
|
||||
// SetLogLevel sets the log level for the firewall manager
|
||||
|
||||
@@ -202,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
|
||||
|
||||
h := manager.udpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(8000), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||
assert.Equal(t, uint16(8000), h.Port)
|
||||
assert.True(t, h.Fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
|
||||
@@ -226,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
|
||||
|
||||
h := manager.tcpHookOut.Load()
|
||||
require.NotNil(t, h)
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
|
||||
assert.Equal(t, uint16(53), h.port)
|
||||
assert.True(t, h.fn(nil))
|
||||
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
|
||||
assert.Equal(t, uint16(53), h.Port)
|
||||
assert.True(t, h.Fn(nil))
|
||||
assert.True(t, called)
|
||||
|
||||
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
|
||||
|
||||
90
client/firewall/uspfilter/hooks_filter.go
Normal file
90
client/firewall/uspfilter/hooks_filter.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
const (
|
||||
ipv4HeaderMinLen = 20
|
||||
ipv4ProtoOffset = 9
|
||||
ipv4FlagsOffset = 6
|
||||
ipv4DstOffset = 16
|
||||
ipProtoUDP = 17
|
||||
ipProtoTCP = 6
|
||||
ipv4FragOffMask = 0x1fff
|
||||
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
|
||||
dstPortOffset = 2
|
||||
)
|
||||
|
||||
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
|
||||
// It is installed on the WireGuard interface when the userspace bind is active
|
||||
// but a full firewall filter (Manager) is not needed because a native kernel
|
||||
// firewall (nftables/iptables) handles packet filtering.
|
||||
type HooksFilter struct {
|
||||
udpHook atomic.Pointer[common.PacketHook]
|
||||
tcpHook atomic.Pointer[common.PacketHook]
|
||||
}
|
||||
|
||||
var _ device.PacketFilter = (*HooksFilter)(nil)
|
||||
|
||||
// FilterOutbound checks outbound packets for DNS hook matches.
|
||||
// Only IPv4 packets matching the registered hook IP:port are intercepted.
|
||||
// IPv6 and non-IP packets pass through unconditionally.
|
||||
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
|
||||
if len(packetData) < ipv4HeaderMinLen {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only process IPv4 packets, let everything else pass through.
|
||||
if packetData[0]>>4 != 4 {
|
||||
return false
|
||||
}
|
||||
|
||||
ihl := int(packetData[0]&0x0f) * 4
|
||||
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip non-first fragments: they don't carry L4 headers.
|
||||
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
|
||||
if flagsAndOffset&ipv4FragOffMask != 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
proto := packetData[ipv4ProtoOffset]
|
||||
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
|
||||
|
||||
switch proto {
|
||||
case ipProtoUDP:
|
||||
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
|
||||
case ipProtoTCP:
|
||||
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// FilterInbound allows all inbound packets (native firewall handles filtering).
|
||||
func (f *HooksFilter) FilterInbound([]byte, int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// SetUDPPacketHook registers the UDP packet hook.
|
||||
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||
common.SetHook(&f.udpHook, ip, dPort, hook)
|
||||
}
|
||||
|
||||
// SetTCPPacketHook registers the TCP packet hook.
|
||||
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
|
||||
common.SetHook(&f.tcpHook, ip, dPort, hook)
|
||||
}
|
||||
Reference in New Issue
Block a user