[client] Add TCP DNS support for local listener (#5758)

This commit is contained in:
Viktor Liu
2026-04-08 13:40:36 +08:00
committed by GitHub
parent 1d920d700c
commit cb73b94ffb
28 changed files with 1615 additions and 417 deletions

View File

@@ -140,6 +140,17 @@ type Manager struct {
mtu uint16
mssClampValue uint16
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -594,6 +605,8 @@ func (m *Manager) resetState() {
maps.Clear(m.incomingRules)
maps.Clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil)
if m.udpTracker != nil {
m.udpTracker.Close()
@@ -713,6 +726,9 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return true
}
case layers.LayerTypeTCP:
if m.tcpHooksDrop(uint16(d.tcp.DstPort), dstIP, packetData) {
return true
}
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
@@ -895,38 +911,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
d.dnatOrigPort = 0
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
// Check IPv4 unspecified address
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
}
@@ -1278,12 +1277,6 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
return rule.mgmtId, rule.drop, true
}
case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook
if rule.udpHook != nil {
return rule.mgmtId, rule.udpHook(packetData), true
}
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.mgmtId, rule.drop, true
}
@@ -1342,65 +1335,30 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
return sourceMatched
}
// AddUDPPacketHook calls hook when UDP packet from given direction matched
//
// Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
r := PeerRule{
id: uuid.New().String(),
ip: ip,
protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6,
udpHook: hook,
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.udpHookOut.Store(nil)
return
}
if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4
}
m.mutex.Lock()
if in {
// Incoming UDP hooks are stored in allow rules map
if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip] = make(map[string]PeerRule)
}
m.incomingRules[r.ip][r.id] = r
} else {
if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip] = make(map[string]PeerRule)
}
m.outgoingRules[r.ip][r.id] = r
}
m.mutex.Unlock()
return r.id
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// RemovePacketHook removes packet hook by given ID
func (m *Manager) RemovePacketHook(hookID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
// Check incoming hooks (stored in allow rules)
for _, arr := range m.incomingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
// Check outgoing hooks
for _, arr := range m.outgoingRules {
for _, r := range arr {
if r.id == hookID {
delete(arr, r.id)
return nil
}
}
}
return fmt.Errorf("hook with given id not found")
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -12,6 +12,7 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
wgdevice "golang.zx2c4.com/wireguard/device"
@@ -186,81 +187,52 @@ func TestManagerDeleteRule(t *testing.T) {
}
}
func TestAddUDPPacketHook(t *testing.T) {
tests := []struct {
name string
in bool
expDir fw.RuleDirection
ip netip.Addr
dPort uint16
hook func([]byte) bool
expectedID string
}{
{
name: "Test Outgoing UDP Packet Hook",
in: false,
expDir: fw.RuleDirectionOUT,
ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000,
hook: func([]byte) bool { return true },
},
{
name: "Test Incoming UDP Packet Hook",
in: true,
expDir: fw.RuleDirectionIN,
ip: netip.MustParseAddr("::1"),
dPort: 9000,
hook: func([]byte) bool { return false },
},
}
func TestSetUDPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
var called bool
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, func([]byte) bool {
called = true
return true
})
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
var addedRule PeerRule
if tt.in {
// Incoming UDP hooks are stored in allow rules map
if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules[tt.ip]))
return
}
for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule
}
} else {
if len(manager.outgoingRules[tt.ip]) != 1 {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules[tt.ip]))
return
}
for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
assert.Nil(t, manager.udpHookOut.Load())
}
if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return
}
if tt.dPort != addedRule.dPort.Values[0] {
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
return
}
if layers.LayerTypeUDP != addedRule.protoLayer {
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
return
}
if addedRule.udpHook == nil {
t.Errorf("expected udpHook to be set")
return
}
})
}
func TestSetTCPPacketHook(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, manager.Close(nil)) })
var called bool
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, func([]byte) bool {
called = true
return true
})
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)
assert.Nil(t, manager.tcpHookOut.Load())
}
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
@@ -530,39 +502,12 @@ func TestRemovePacketHook(t *testing.T) {
require.NoError(t, manager.Close(nil))
}()
// Add a UDP packet hook
hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, func([]byte) bool { return true })
// Assert the hook is added by finding it in the manager's outgoing rules
found := false
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
found = true
break
}
}
}
require.NotNil(t, manager.udpHookOut.Load(), "hook should be registered")
if !found {
t.Fatalf("The hook was not added properly.")
}
// Now remove the packet hook
err = manager.RemovePacketHook(hookID)
if err != nil {
t.Fatalf("Failed to remove hook: %s", err)
}
// Assert the hook is removed by checking it in the manager's outgoing rules
for _, arr := range manager.outgoingRules {
for _, rule := range arr {
if rule.id == hookID {
t.Fatalf("The hook was not removed properly.")
}
}
}
manager.SetUDPPacketHook(netip.MustParseAddr("192.168.0.1"), 8080, nil)
assert.Nil(t, manager.udpHookOut.Load(), "hook should be removed")
}
func TestProcessOutgoingHooks(t *testing.T) {
@@ -592,8 +537,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
}
hookCalled := false
hookID := manager.AddUDPPacketHook(
false,
manager.SetUDPPacketHook(
netip.MustParseAddr("100.10.0.100"),
53,
func([]byte) bool {
@@ -601,7 +545,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
return true
},
)
require.NotEmpty(t, hookID)
// Create test UDP packet
ipv4 := &layers.IPv4{

View File

@@ -421,6 +421,7 @@ func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.Laye
}
// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services.
// TODO: also delegate to nativeFirewall when available for kernel WG mode
func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
var layerType gopacket.LayerType
switch protocol {
@@ -466,6 +467,22 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot
return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort)
}
// AddOutputDNAT delegates to the native firewall if available.
func (m *Manager) AddOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return fmt.Errorf("output DNAT not supported without native firewall")
}
return m.nativeFirewall.AddOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// RemoveOutputDNAT delegates to the native firewall if available.
func (m *Manager) RemoveOutputDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveOutputDNAT(localAddr, protocol, sourcePort, targetPort)
}
// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets.
func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.portDNATEnabled.Load() {

View File

@@ -18,9 +18,7 @@ type PeerRule struct {
protoLayer gopacket.LayerType
sPort *firewall.Port
dPort *firewall.Port
drop bool
udpHook func([]byte) bool
drop bool
}
// ID returns the rule id

View File

@@ -399,21 +399,17 @@ func TestTracePacket(t *testing.T) {
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
m.SetUDPPacketHook(netip.MustParseAddr("100.10.255.254"), 53, func([]byte) bool {
return true // drop (intercepted by hook)
})
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
return createPacketBuilder("100.10.0.100", "100.10.255.254", "udp", 12345, 53, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageInboundPortDNAT,
StageInbound1to1NAT,
StageConntrack,
StageRouting,
StagePeerACL,
StageOutbound1to1NAT,
StageOutboundPortReverse,
StageCompleted,
},
expectedAllow: false,