[client] Reorder userspace ACL checks to fail faster for better performance (#4226)

This commit is contained in:
Viktor Liu
2026-01-09 16:13:04 +08:00
committed by GitHub
parent e8863fbb55
commit 0ad0c81899
6 changed files with 72 additions and 69 deletions

View File

@@ -29,7 +29,7 @@ import (
) )
const ( const (
layerTypeAll = 0 layerTypeAll = 255
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40 ipTCPHeaderMinSize = 40
@@ -262,10 +262,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix := iface.Address().Network
if err != nil {
return nil, fmt.Errorf("parse wireguard network: %w", err)
}
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
rule, err := m.addRouteFiltering( rule, err := m.addRouteFiltering(
@@ -439,19 +436,7 @@ func (m *Manager) AddPeerFiltering(
r.sPort = sPort r.sPort = sPort
r.dPort = dPort r.dPort = dPort
switch proto { r.protoLayer = protoToLayer(proto, r.ipLayer)
case firewall.ProtocolTCP:
r.protoLayer = layers.LayerTypeTCP
case firewall.ProtocolUDP:
r.protoLayer = layers.LayerTypeUDP
case firewall.ProtocolICMP:
r.protoLayer = layers.LayerTypeICMPv4
if r.ipLayer == layers.LayerTypeIPv6 {
r.protoLayer = layers.LayerTypeICMPv6
}
case firewall.ProtocolALL:
r.protoLayer = layerTypeAll
}
m.mutex.Lock() m.mutex.Lock()
var targetMap map[netip.Addr]RuleSet var targetMap map[netip.Addr]RuleSet
@@ -496,16 +481,17 @@ func (m *Manager) addRouteFiltering(
} }
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs // TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id, mgmtId: id,
sources: sources, sources: sources,
dstSet: destination.Set, dstSet: destination.Set,
proto: proto, protoLayer: protoToLayer(proto, layers.LayerTypeIPv4),
srcPort: sPort, srcPort: sPort,
dstPort: dPort, dstPort: dPort,
action: action, action: action,
} }
if destination.IsPrefix() { if destination.IsPrefix() {
rule.destinations = []netip.Prefix{destination.Prefix} rule.destinations = []netip.Prefix{destination.Prefix}
@@ -945,7 +931,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool {
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData) ruleID, blocked := m.peerACLsBlock(srcIP, d, packetData)
if blocked { if blocked {
_, pnum := getProtocolFromPacket(d) pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace6("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
@@ -1010,20 +996,22 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return false return false
} }
proto, pnum := getProtocolFromPacket(d) protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) ruleID, pass := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
if !pass { if !pass {
proto := getProtocolFromPacket(d)
m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d", m.logger.Trace6("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort) ruleID, proto, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{ m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(), FlowID: uuid.New(),
Type: nftypes.TypeDrop, Type: nftypes.TypeDrop,
RuleID: ruleID, RuleID: ruleID,
Direction: nftypes.Ingress, Direction: nftypes.Ingress,
Protocol: pnum, Protocol: proto,
SourceIP: srcIP, SourceIP: srcIP,
DestIP: dstIP, DestIP: dstIP,
SourcePort: srcPort, SourcePort: srcPort,
@@ -1052,16 +1040,33 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
return true return true
} }
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) { func protoToLayer(proto firewall.Protocol, ipLayer gopacket.LayerType) gopacket.LayerType {
switch proto {
case firewall.ProtocolTCP:
return layers.LayerTypeTCP
case firewall.ProtocolUDP:
return layers.LayerTypeUDP
case firewall.ProtocolICMP:
if ipLayer == layers.LayerTypeIPv6 {
return layers.LayerTypeICMPv6
}
return layers.LayerTypeICMPv4
case firewall.ProtocolALL:
return layerTypeAll
}
return 0
}
func getProtocolFromPacket(d *decoder) nftypes.Protocol {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return firewall.ProtocolTCP, nftypes.TCP return nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return firewall.ProtocolUDP, nftypes.UDP return nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP, nftypes.ICMP return nftypes.ICMP
default: default:
return firewall.ProtocolALL, nftypes.ProtocolUnknown return nftypes.ProtocolUnknown
} }
} }
@@ -1233,19 +1238,30 @@ func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d
} }
// routeACLsPass returns true if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches { if matches := m.ruleMatches(rule, srcIP, dstIP, protoLayer, srcPort, dstPort); matches {
return rule.mgmtId, rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
return nil, false return nil, false
} }
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, protoLayer gopacket.LayerType, srcPort, dstPort uint16) bool {
// TODO: handle ipv6 vs ipv4 icmp rules
if rule.protoLayer != layerTypeAll && rule.protoLayer != protoLayer {
return false
}
if protoLayer == layers.LayerTypeTCP || protoLayer == layers.LayerTypeUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
destMatched := false destMatched := false
for _, dst := range rule.destinations { for _, dst := range rule.destinations {
if dst.Contains(dstAddr) { if dst.Contains(dstAddr) {
@@ -1264,21 +1280,8 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
break break
} }
} }
if !sourceMatched {
return false
}
if rule.proto != firewall.ProtocolALL && rule.proto != proto { return sourceMatched
return false
}
if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP {
if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) {
return false
}
}
return true
} }
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched

View File

@@ -955,7 +955,7 @@ func BenchmarkRouteACLs(b *testing.B) {
for _, tc := range cases { for _, tc := range cases {
srcIP := netip.MustParseAddr(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := netip.MustParseAddr(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), 0, tc.dstPort)
} }
} }
} }

View File

@@ -1259,7 +1259,7 @@ func TestRouteACLFiltering(t *testing.T) {
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed // testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(tc.proto, layers.LayerTypeIPv4), tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)
}) })
} }
@@ -1445,7 +1445,7 @@ func TestRouteACLOrder(t *testing.T) {
srcIP := netip.MustParseAddr(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := netip.MustParseAddr(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(p.proto, layers.LayerTypeIPv4), p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
} }
}) })
@@ -1488,13 +1488,13 @@ func TestRouteACLSet(t *testing.T) {
dstIP := netip.MustParseAddr("192.168.1.100") dstIP := netip.MustParseAddr("192.168.1.100")
// Check that traffic is dropped (empty set shouldn't match anything) // Check that traffic is dropped (empty set shouldn't match anything)
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.False(t, isAllowed, "Empty set should not allow any traffic") require.False(t, isAllowed, "Empty set should not allow any traffic")
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}) err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
require.NoError(t, err) require.NoError(t, err)
// Now the packet should be allowed // Now the packet should be allowed
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80) _, isAllowed = manager.routeACLsPass(srcIP, dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed") require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
} }

View File

@@ -767,9 +767,9 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP2 := netip.MustParseAddr("192.168.1.100") dstIP2 := netip.MustParseAddr("192.168.1.100")
dstIP3 := netip.MustParseAddr("172.16.0.100") dstIP3 := netip.MustParseAddr("172.16.0.100")
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) _, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) _, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80) _, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed") require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
@@ -784,8 +784,8 @@ func TestUpdateSetMerge(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Check that all original prefixes are still included // Check that all original prefixes are still included
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80) _, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80) _, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update") require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update") require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
@@ -793,8 +793,8 @@ func TestUpdateSetMerge(t *testing.T) {
dstIP4 := netip.MustParseAddr("172.16.1.100") dstIP4 := netip.MustParseAddr("172.16.1.100")
dstIP5 := netip.MustParseAddr("10.1.0.50") dstIP5 := netip.MustParseAddr("10.1.0.50")
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80) _, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80) _, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed") require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed") require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
@@ -922,7 +922,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
srcIP := netip.MustParseAddr("100.10.0.1") srcIP := netip.MustParseAddr("100.10.0.1")
for _, tc := range testCases { for _, tc := range testCases {
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80) _, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, protoToLayer(fw.ProtocolTCP, layers.LayerTypeIPv4), 12345, 80)
require.Equal(t, tc.expected, isAllowed, tc.desc) require.Equal(t, tc.expected, isAllowed, tc.desc)
} }
} }

View File

@@ -34,7 +34,7 @@ type RouteRule struct {
sources []netip.Prefix sources []netip.Prefix
dstSet firewall.Set dstSet firewall.Set
destinations []netip.Prefix destinations []netip.Prefix
proto firewall.Protocol protoLayer gopacket.LayerType
srcPort *firewall.Port srcPort *firewall.Port
dstPort *firewall.Port dstPort *firewall.Port
action firewall.Action action firewall.Action

View File

@@ -379,9 +379,9 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto, _ := getProtocolFromPacket(d) protoLayer := d.decoded[1]
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, protoLayer, srcPort, dstPort)
strId := string(id) strId := string(id)
if id == nil { if id == nil {