diff --git a/go.mod b/go.mod index d167bd7..03ff1c2 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25 require ( github.com/docker/docker v28.5.2+incompatible + github.com/gaissmai/bart v0.26.0 github.com/gorilla/websocket v1.5.3 github.com/prometheus/client_golang v1.23.2 github.com/vishvananda/netlink v1.3.1 @@ -40,7 +41,6 @@ require ( github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.4.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/gaissmai/bart v0.26.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/google/btree v1.1.3 // indirect diff --git a/netstack2/subnet_lookup.go b/netstack2/subnet_lookup.go index 1a8414d..fcfed63 100644 --- a/netstack2/subnet_lookup.go +++ b/netstack2/subnet_lookup.go @@ -38,6 +38,12 @@ func NewSubnetLookup() *SubnetLookup { } } +// prefixEqual compares two prefixes after masking to handle host bits correctly. +// For example, 10.0.0.5/24 and 10.0.0.0/24 are treated as equal. +func prefixEqual(a, b netip.Prefix) bool { + return a.Masked() == b.Masked() +} + // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") @@ -53,26 +59,34 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite PortRanges: portRanges, } + // Canonicalize source prefix to handle host bits correctly + canonicalSourcePrefix := sourcePrefix.Masked() + // Get or create destination trie for this source prefix - destTriePtr, exists := sl.sourceTrie.Get(sourcePrefix) + destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix) if !exists { // Create new destination trie for this source prefix destTriePtr = &destTrie{ trie: &bart.Table[[]*SubnetRule]{}, rules: make([]*SubnetRule, 0), } - sl.sourceTrie.Insert(sourcePrefix, destTriePtr) + sl.sourceTrie.Insert(canonicalSourcePrefix, destTriePtr) } + // Canonicalize destination prefix to handle host bits correctly + // BART masks prefixes internally, so we need to match that behavior in our bookkeeping + canonicalDestPrefix := destPrefix.Masked() + // Add rule to destination trie // Original behavior: overwrite if same (sourcePrefix, destPrefix) exists // Store as single-element slice to match original overwrite behavior - destTriePtr.trie.Insert(destPrefix, []*SubnetRule{rule}) + destTriePtr.trie.Insert(canonicalDestPrefix, []*SubnetRule{rule}) - // Update destTriePtr.rules - remove old rule with same prefix if exists, then add new one + // Update destTriePtr.rules - remove old rule with same canonical prefix if exists, then add new one + // Use canonical comparison to handle cases like 10.0.0.5/24 vs 10.0.0.0/24 newRules := make([]*SubnetRule, 0, len(destTriePtr.rules)+1) for _, r := range destTriePtr.rules { - if r.DestPrefix != destPrefix { + if !prefixEqual(r.DestPrefix, canonicalDestPrefix) || !prefixEqual(r.SourcePrefix, canonicalSourcePrefix) { newRules = append(newRules, r) } } @@ -85,26 +99,35 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { sl.mu.Lock() defer sl.mu.Unlock() - destTriePtr, exists := sl.sourceTrie.Get(sourcePrefix) + // Canonicalize prefixes to handle host bits correctly + canonicalSourcePrefix := sourcePrefix.Masked() + canonicalDestPrefix := destPrefix.Masked() + + destTriePtr, exists := sl.sourceTrie.Get(canonicalSourcePrefix) if !exists { return } // Remove the rule - original behavior: delete exact (sourcePrefix, destPrefix) combination - destTriePtr.trie.Delete(destPrefix) + // BART masks prefixes internally, so Delete works with canonical form + destTriePtr.trie.Delete(canonicalDestPrefix) - // Also remove from destTriePtr.rules + // Also remove from destTriePtr.rules using canonical comparison + // This ensures we remove rules even if they were added with host bits set newDestRules := make([]*SubnetRule, 0, len(destTriePtr.rules)) for _, r := range destTriePtr.rules { - if r.DestPrefix != destPrefix { + if !prefixEqual(r.DestPrefix, canonicalDestPrefix) || !prefixEqual(r.SourcePrefix, canonicalSourcePrefix) { newDestRules = append(newDestRules, r) } } destTriePtr.rules = newDestRules - // If no more rules for this source prefix, remove it - if len(destTriePtr.rules) == 0 { - sl.sourceTrie.Delete(sourcePrefix) + // Check if the trie is actually empty using BART's Size() method + // This is more efficient than iterating and ensures we clean up empty tries + // even if there were stale entries in the rules slice (which shouldn't happen + // with proper canonicalization, but this provides a definitive check) + if destTriePtr.trie.Size() == 0 { + sl.sourceTrie.Delete(canonicalSourcePrefix) } }