mirror of
https://github.com/fosrl/newt.git
synced 2026-03-05 10:16:44 +00:00
@@ -112,6 +112,8 @@ func NewWireGuardService(interfaceName string, port uint16, mtu int, host string
|
||||
return nil, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
logger.Debug("+++++++++++++++++++++++++++++++= the port is %d", port)
|
||||
|
||||
if port == 0 {
|
||||
// Find an available port
|
||||
portRandom, err := util.FindAvailableUDPPort(49152, 65535)
|
||||
@@ -724,7 +726,7 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1117,7 +1119,7 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1234,7 +1236,7 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
|
||||
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v disableIcmp: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange, target.DisableIcmp)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
1
go.mod
1
go.mod
@@ -4,6 +4,7 @@ go 1.25.0
|
||||
|
||||
require (
|
||||
github.com/docker/docker v28.5.2+incompatible
|
||||
github.com/gaissmai/bart v0.26.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/prometheus/client_golang v1.23.2
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
|
||||
2
go.sum
2
go.sum
@@ -26,6 +26,8 @@ github.com/docker/go-units v0.4.0 h1:3uh0PgVws3nIA0Q+MwDC8yjEPf9zjRfZZWXZYDct3Tw
|
||||
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/gaissmai/bart v0.26.0 h1:xOZ57E9hJLBiQaSyeZa9wgWhGuzfGACgqp4BE77OkO0=
|
||||
github.com/gaissmai/bart v0.26.0/go.mod h1:GREWQfTLRWz/c5FTOsIw+KkscuFkIV5t8Rp7Nd1Td5c=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
|
||||
20
main.go
20
main.go
@@ -347,15 +347,6 @@ func runNewtMain(ctx context.Context) {
|
||||
pingTimeout = 5 * time.Second
|
||||
}
|
||||
|
||||
if portStr != "" {
|
||||
portInt, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to parse PORT, choosing a random port")
|
||||
} else {
|
||||
port = uint16(portInt)
|
||||
}
|
||||
}
|
||||
|
||||
if dockerEnforceNetworkValidation == "" {
|
||||
flag.StringVar(&dockerEnforceNetworkValidation, "docker-enforce-network-validation", "false", "Enforce validation of container on newt network (true or false)")
|
||||
}
|
||||
@@ -441,6 +432,15 @@ func runNewtMain(ctx context.Context) {
|
||||
tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...)
|
||||
}
|
||||
|
||||
if portStr != "" {
|
||||
portInt, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to parse PORT, choosing a random port")
|
||||
} else {
|
||||
port = uint16(portInt)
|
||||
}
|
||||
}
|
||||
|
||||
if *version {
|
||||
fmt.Println("Newt version " + newtVersion)
|
||||
os.Exit(0)
|
||||
@@ -618,6 +618,8 @@ func runNewtMain(ctx context.Context) {
|
||||
var connected bool
|
||||
var wgData WgData
|
||||
var dockerEventMonitor *docker.EventMonitor
|
||||
|
||||
logger.Debug("++++++++++++++++++++++ the port is %d", port)
|
||||
|
||||
if !disableClients {
|
||||
setupClients(client)
|
||||
|
||||
@@ -48,115 +48,6 @@ type SubnetRule struct {
|
||||
PortRanges []PortRange // empty slice means all ports allowed
|
||||
}
|
||||
|
||||
// ruleKey is used as a map key for fast O(1) lookups
|
||||
type ruleKey struct {
|
||||
sourcePrefix string
|
||||
destPrefix string
|
||||
}
|
||||
|
||||
// SubnetLookup provides fast IP subnet and port matching with O(1) lookup performance
|
||||
type SubnetLookup struct {
|
||||
mu sync.RWMutex
|
||||
rules map[ruleKey]*SubnetRule // Map for O(1) lookups by prefix combination
|
||||
}
|
||||
|
||||
// NewSubnetLookup creates a new subnet lookup table
|
||||
func NewSubnetLookup() *SubnetLookup {
|
||||
return &SubnetLookup{
|
||||
rules: make(map[ruleKey]*SubnetRule),
|
||||
}
|
||||
}
|
||||
|
||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
key := ruleKey{
|
||||
sourcePrefix: sourcePrefix.String(),
|
||||
destPrefix: destPrefix.String(),
|
||||
}
|
||||
|
||||
sl.rules[key] = &SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
DisableIcmp: disableIcmp,
|
||||
RewriteTo: rewriteTo,
|
||||
PortRanges: portRanges,
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveSubnet removes a subnet rule from the lookup table
|
||||
func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
key := ruleKey{
|
||||
sourcePrefix: sourcePrefix.String(),
|
||||
destPrefix: destPrefix.String(),
|
||||
}
|
||||
|
||||
delete(sl.rules, key)
|
||||
}
|
||||
|
||||
// Match checks if a source IP, destination IP, port, and protocol match any subnet rule
|
||||
// Returns the matched rule if ALL of these conditions are met:
|
||||
// - The source IP is in the rule's source prefix
|
||||
// - The destination IP is in the rule's destination prefix
|
||||
// - The port is in an allowed range (or no port restrictions exist)
|
||||
// - The protocol matches (or the port range allows both protocols)
|
||||
//
|
||||
// proto should be header.TCPProtocolNumber or header.UDPProtocolNumber
|
||||
// Returns nil if no rule matches
|
||||
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
// Iterate through all rules to find matching source and destination prefixes
|
||||
// This is O(n) but necessary since we need to check prefix containment, not exact match
|
||||
for _, rule := range sl.rules {
|
||||
// Check if source and destination IPs match their respective prefixes
|
||||
if !rule.SourcePrefix.Contains(srcIP) {
|
||||
continue
|
||||
}
|
||||
if !rule.DestPrefix.Contains(dstIP) {
|
||||
continue
|
||||
}
|
||||
|
||||
if rule.DisableIcmp && (proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber) {
|
||||
// ICMP is disabled for this subnet
|
||||
return nil
|
||||
}
|
||||
|
||||
// Both IPs match - now check port restrictions
|
||||
// If no port ranges specified, all ports are allowed
|
||||
if len(rule.PortRanges) == 0 {
|
||||
return rule
|
||||
}
|
||||
|
||||
// Check if port and protocol are in any of the allowed ranges
|
||||
for _, pr := range rule.PortRanges {
|
||||
if port >= pr.Min && port <= pr.Max {
|
||||
// Check protocol compatibility
|
||||
if pr.Protocol == "" {
|
||||
// Empty protocol means allow both TCP and UDP
|
||||
return rule
|
||||
}
|
||||
// Check if the packet protocol matches the port range protocol
|
||||
if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) ||
|
||||
(pr.Protocol == "udp" && proto == header.UDPProtocolNumber) {
|
||||
return rule
|
||||
}
|
||||
// Port matches but protocol doesn't - continue checking other ranges
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// connKey uniquely identifies a connection for NAT tracking
|
||||
type connKey struct {
|
||||
srcIP string
|
||||
@@ -166,6 +57,17 @@ type connKey struct {
|
||||
proto uint8
|
||||
}
|
||||
|
||||
// reverseConnKey uniquely identifies a connection for reverse NAT lookup (reply direction)
|
||||
// Key structure: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto)
|
||||
// This allows O(1) lookup of NAT entries for reply packets
|
||||
type reverseConnKey struct {
|
||||
rewrittenTo string // The address we rewrote to (becomes src in replies)
|
||||
originalSrcIP string // Original source IP (becomes dst in replies)
|
||||
originalSrcPort uint16 // Original source port (becomes dst port in replies)
|
||||
originalDstPort uint16 // Original destination port (becomes src port in replies)
|
||||
proto uint8
|
||||
}
|
||||
|
||||
// destKey identifies a destination for handler lookups (without source port since it may change)
|
||||
type destKey struct {
|
||||
srcIP string
|
||||
@@ -190,7 +92,8 @@ type ProxyHandler struct {
|
||||
icmpHandler *ICMPHandler
|
||||
subnetLookup *SubnetLookup
|
||||
natTable map[connKey]*natState
|
||||
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
|
||||
reverseNatTable map[reverseConnKey]*natState // Reverse lookup map for O(1) reply packet NAT
|
||||
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
|
||||
natMu sync.RWMutex
|
||||
enabled bool
|
||||
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
|
||||
@@ -215,6 +118,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
|
||||
enabled: true,
|
||||
subnetLookup: NewSubnetLookup(),
|
||||
natTable: make(map[connKey]*natState),
|
||||
reverseNatTable: make(map[reverseConnKey]*natState),
|
||||
destRewriteTable: make(map[destKey]netip.Addr),
|
||||
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
|
||||
proxyEp: channel.New(1024, uint32(options.MTU), ""),
|
||||
@@ -517,10 +421,23 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
|
||||
|
||||
// Store NAT state for this connection
|
||||
p.natMu.Lock()
|
||||
p.natTable[key] = &natState{
|
||||
natEntry := &natState{
|
||||
originalDst: dstAddr,
|
||||
rewrittenTo: newDst,
|
||||
}
|
||||
p.natTable[key] = natEntry
|
||||
|
||||
// Create reverse lookup key for O(1) reply packet lookups
|
||||
// Key: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto)
|
||||
reverseKey := reverseConnKey{
|
||||
rewrittenTo: newDst.String(),
|
||||
originalSrcIP: srcAddr.String(),
|
||||
originalSrcPort: srcPort,
|
||||
originalDstPort: dstPort,
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
p.reverseNatTable[reverseKey] = natEntry
|
||||
|
||||
// Store destination rewrite for handler lookups
|
||||
p.destRewriteTable[dKey] = newDst
|
||||
p.natMu.Unlock()
|
||||
@@ -719,20 +636,22 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
|
||||
return view
|
||||
}
|
||||
|
||||
// Look up NAT state for reverse translation
|
||||
// The key uses the original dst (before rewrite), so for replies we need to
|
||||
// find the entry where the rewritten address matches the current source
|
||||
// Look up NAT state for reverse translation using O(1) reverse lookup map
|
||||
// Key: (rewrittenTo, originalSrcIP, originalSrcPort, originalDstPort, proto)
|
||||
// For reply packets:
|
||||
// - reply's srcIP = rewrittenTo (the address we rewrote to)
|
||||
// - reply's dstIP = originalSrcIP (original source IP)
|
||||
// - reply's srcPort = originalDstPort (original destination port)
|
||||
// - reply's dstPort = originalSrcPort (original source port)
|
||||
p.natMu.RLock()
|
||||
var natEntry *natState
|
||||
for k, entry := range p.natTable {
|
||||
// Match: reply's dst should be original src, reply's src should be rewritten dst
|
||||
if k.srcIP == dstIP.String() && k.srcPort == dstPort &&
|
||||
entry.rewrittenTo.String() == srcIP.String() && k.dstPort == srcPort &&
|
||||
k.proto == uint8(protocol) {
|
||||
natEntry = entry
|
||||
break
|
||||
}
|
||||
reverseKey := reverseConnKey{
|
||||
rewrittenTo: srcIP.String(), // Reply's source is the rewritten address
|
||||
originalSrcIP: dstIP.String(), // Reply's destination is the original source
|
||||
originalSrcPort: dstPort, // Reply's destination port is the original source port
|
||||
originalDstPort: srcPort, // Reply's source port is the original destination port
|
||||
proto: uint8(protocol),
|
||||
}
|
||||
natEntry := p.reverseNatTable[reverseKey]
|
||||
p.natMu.RUnlock()
|
||||
|
||||
if natEntry != nil {
|
||||
|
||||
206
netstack2/subnet_lookup.go
Normal file
206
netstack2/subnet_lookup.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package netstack2
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
)
|
||||
|
||||
// SubnetLookup provides fast IP subnet and port matching using BART (Binary Aggregated Range Tree)
|
||||
// This uses BART Table for O(log n) prefix matching with Supernets() for efficient lookups
|
||||
//
|
||||
// Architecture:
|
||||
// - Two-level BART structure for matching both source AND destination prefixes
|
||||
// - Level 1: Source prefix -> Level 2 (destination prefix -> rules)
|
||||
// - This reduces search space: only check destination prefixes for matching source prefixes
|
||||
type SubnetLookup struct {
|
||||
mu sync.RWMutex
|
||||
// Two-level BART structure:
|
||||
// Level 1: Source prefix -> Level 2 (destination prefix -> rules)
|
||||
// This allows us to first match source prefix, then only check destination prefixes
|
||||
// for matching source prefixes, reducing the search space significantly
|
||||
sourceTrie *bart.Table[*destTrie]
|
||||
}
|
||||
|
||||
// destTrie is a BART for destination prefixes, containing the actual rules
|
||||
type destTrie struct {
|
||||
trie *bart.Table[[]*SubnetRule]
|
||||
rules []*SubnetRule // All rules for this source prefix (for iteration if needed)
|
||||
}
|
||||
|
||||
// NewSubnetLookup creates a new subnet lookup table using BART
|
||||
func NewSubnetLookup() *SubnetLookup {
|
||||
return &SubnetLookup{
|
||||
sourceTrie: &bart.Table[*destTrie]{},
|
||||
}
|
||||
}
|
||||
|
||||
// prefixEqual compares two prefixes after masking to handle host bits correctly.
|
||||
// For example, 10.0.0.5/24 and 10.0.0.0/24 are treated as equal.
|
||||
func prefixEqual(a, b netip.Prefix) bool {
|
||||
return a.Masked() == b.Masked()
|
||||
}
|
||||
|
||||
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
|
||||
// If portRanges is nil or empty, all ports are allowed for this subnet
|
||||
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
|
||||
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
rule := &SubnetRule{
|
||||
SourcePrefix: sourcePrefix,
|
||||
DestPrefix: destPrefix,
|
||||
DisableIcmp: disableIcmp,
|
||||
RewriteTo: rewriteTo,
|
||||
PortRanges: portRanges,
|
||||
}
|
||||
|
||||
// Canonicalize source prefix to handle host bits correctly
|
||||
canonicalSourcePrefix := sourcePrefix.Masked()
|
||||
|
||||
// Get or create destination trie for this source prefix
|
||||
destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix)
|
||||
if !exists {
|
||||
// Create new destination trie for this source prefix
|
||||
destTriePtr = &destTrie{
|
||||
trie: &bart.Table[[]*SubnetRule]{},
|
||||
rules: make([]*SubnetRule, 0),
|
||||
}
|
||||
sl.sourceTrie.Insert(canonicalSourcePrefix, destTriePtr)
|
||||
}
|
||||
|
||||
// Canonicalize destination prefix to handle host bits correctly
|
||||
// BART masks prefixes internally, so we need to match that behavior in our bookkeeping
|
||||
canonicalDestPrefix := destPrefix.Masked()
|
||||
|
||||
// Add rule to destination trie
|
||||
// Original behavior: overwrite if same (sourcePrefix, destPrefix) exists
|
||||
// Store as single-element slice to match original overwrite behavior
|
||||
destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rule})
|
||||
|
||||
// Update destTriePtr.rules - remove old rule with same canonical prefix if exists, then add new one
|
||||
// Use canonical comparison to handle cases like 10.0.0.5/24 vs 10.0.0.0/24
|
||||
newRules := make([]*SubnetRule, 0, len(destTriePtr.rules)+1)
|
||||
for _, r := range destTriePtr.rules {
|
||||
if !prefixEqual(r.DestPrefix, canonicalDestPrefix) || !prefixEqual(r.SourcePrefix, canonicalSourcePrefix) {
|
||||
newRules = append(newRules, r)
|
||||
}
|
||||
}
|
||||
newRules = append(newRules, rule)
|
||||
destTriePtr.rules = newRules
|
||||
}
|
||||
|
||||
// RemoveSubnet removes a subnet rule from the lookup table
|
||||
func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) {
|
||||
sl.mu.Lock()
|
||||
defer sl.mu.Unlock()
|
||||
|
||||
// Canonicalize prefixes to handle host bits correctly
|
||||
canonicalSourcePrefix := sourcePrefix.Masked()
|
||||
canonicalDestPrefix := destPrefix.Masked()
|
||||
|
||||
destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the rule - original behavior: delete exact (sourcePrefix, destPrefix) combination
|
||||
// BART masks prefixes internally, so Delete works with canonical form
|
||||
destTriePtr.trie.Delete(canonicalDestPrefix)
|
||||
|
||||
// Also remove from destTriePtr.rules using canonical comparison
|
||||
// This ensures we remove rules even if they were added with host bits set
|
||||
newDestRules := make([]*SubnetRule, 0, len(destTriePtr.rules))
|
||||
for _, r := range destTriePtr.rules {
|
||||
if !prefixEqual(r.DestPrefix, canonicalDestPrefix) || !prefixEqual(r.SourcePrefix, canonicalSourcePrefix) {
|
||||
newDestRules = append(newDestRules, r)
|
||||
}
|
||||
}
|
||||
destTriePtr.rules = newDestRules
|
||||
|
||||
// Check if the trie is actually empty using BART's Size() method
|
||||
// This is more efficient than iterating and ensures we clean up empty tries
|
||||
// even if there were stale entries in the rules slice (which shouldn't happen
|
||||
// with proper canonicalization, but this provides a definitive check)
|
||||
if destTriePtr.trie.Size() == 0 {
|
||||
sl.sourceTrie.Delete(canonicalSourcePrefix)
|
||||
}
|
||||
}
|
||||
|
||||
// Match checks if a source IP, destination IP, port, and protocol match any subnet rule
|
||||
// Returns the matched rule if ALL of these conditions are met:
|
||||
// - The source IP is in the rule's source prefix
|
||||
// - The destination IP is in the rule's destination prefix
|
||||
// - The port is in an allowed range (or no port restrictions exist)
|
||||
// - The protocol matches (or the port range allows both protocols)
|
||||
//
|
||||
// proto should be header.TCPProtocolNumber, header.UDPProtocolNumber, or header.ICMPv4ProtocolNumber
|
||||
// Returns nil if no rule matches
|
||||
// This uses BART's Supernets() for O(log n) prefix matching instead of O(n) iteration
|
||||
func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule {
|
||||
sl.mu.RLock()
|
||||
defer sl.mu.RUnlock()
|
||||
|
||||
// Convert IP addresses to /32 (IPv4) or /128 (IPv6) prefixes
|
||||
// Supernets() finds all prefixes that contain this IP (i.e., are supernets of /32 or /128)
|
||||
srcPrefix := netip.PrefixFrom(srcIP, srcIP.BitLen())
|
||||
dstPrefix := netip.PrefixFrom(dstIP, dstIP.BitLen())
|
||||
|
||||
// Step 1: Find all source prefixes that contain srcIP using BART's Supernets
|
||||
// This is O(log n) instead of O(n) iteration
|
||||
// Supernets returns all prefixes that are supernets (contain) the given prefix
|
||||
for _, destTriePtr := range sl.sourceTrie.Supernets(srcPrefix) {
|
||||
if destTriePtr == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 2: Find all destination prefixes that contain dstIP
|
||||
// This is also O(log n) for each matching source prefix
|
||||
for _, rules := range destTriePtr.trie.Supernets(dstPrefix) {
|
||||
if rules == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Step 3: Check each rule for ICMP and port restrictions
|
||||
for _, rule := range rules {
|
||||
// Handle ICMP before port range check — ICMP has no ports
|
||||
if proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber {
|
||||
if rule.DisableIcmp {
|
||||
return nil
|
||||
}
|
||||
// ICMP is allowed; port ranges don't apply to ICMP
|
||||
return rule
|
||||
}
|
||||
|
||||
// Check port restrictions
|
||||
if len(rule.PortRanges) == 0 {
|
||||
// No port restrictions, match!
|
||||
return rule
|
||||
}
|
||||
|
||||
// Check if port and protocol are in any of the allowed ranges
|
||||
for _, pr := range rule.PortRanges {
|
||||
if port >= pr.Min && port <= pr.Max {
|
||||
// Check protocol compatibility
|
||||
if pr.Protocol == "" {
|
||||
// Empty protocol means allow both TCP and UDP
|
||||
return rule
|
||||
}
|
||||
// Check if the packet protocol matches the port range protocol
|
||||
if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) ||
|
||||
(pr.Protocol == "udp" && proto == header.UDPProtocolNumber) {
|
||||
return rule
|
||||
}
|
||||
// Port matches but protocol doesn't - continue checking other ranges
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user