From 5ca48258007eac6be7ee0777aa091d2731ff9a6e Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 26 Feb 2026 11:15:42 +0000 Subject: [PATCH] refactor(dns): trie + unified record set for DNSRecordStore - Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map - Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps - Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged --- dns/dns_records.go | 358 +++++++++++++++++++++++---------------------- 1 file changed, 185 insertions(+), 173 deletions(-) diff --git a/dns/dns_records.go b/dns/dns_records.go index 199b94b..5c62043 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -18,24 +18,49 @@ const ( RecordTypePTR RecordType = RecordType(dns.TypePTR) ) -// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries +// recordSet holds A and AAAA records for a single domain or wildcard pattern +type recordSet struct { + A []net.IP + AAAA []net.IP +} + +// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path) +type domainTrieNode struct { + children map[string]*domainTrieNode + data *recordSet +} + +// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. +// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns +// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA). type DNSRecordStore struct { - mu sync.RWMutex - aRecords map[string][]net.IP // domain -> list of IPv4 addresses - aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses - aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses - aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses - ptrRecords map[string]string // IP address string -> domain name + mu sync.RWMutex + root *domainTrieNode // trie root for exact lookups + wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records + ptrRecords map[string]string // IP address string -> domain name +} + +// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"]) +func domainToPath(domain string) []string { + domain = strings.ToLower(dns.Fqdn(domain)) + domain = strings.TrimSuffix(domain, ".") + if domain == "" { + return nil + } + labels := strings.Split(domain, ".") + path := make([]string, 0, len(labels)) + for i := len(labels) - 1; i >= 0; i-- { + path = append(path, labels[i]) + } + return path } // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - aRecords: make(map[string][]net.IP), - aaaaRecords: make(map[string][]net.IP), - aWildcards: make(map[string][]net.IP), - aaaaWildcards: make(map[string][]net.IP), - ptrRecords: make(map[string]string), + root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, + wildcards: make(map[string]*recordSet), + ptrRecords: make(map[string]string), } } @@ -48,39 +73,47 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip.To4() != nil { - // IPv4 address - if isWildcard { - s.aWildcards[domain] = append(s.aWildcards[domain], ip) - } else { - s.aRecords[domain] = append(s.aRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else if ip.To16() != nil { - // IPv6 address - if isWildcard { - s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) - } else { - s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else { + isV4 := ip.To4() != nil + if !isV4 && ip.To16() == nil { return &net.ParseError{Type: "IP address", Text: ip.String()} } + if isWildcard { + if s.wildcards[domain] == nil { + s.wildcards[domain] = &recordSet{} + } + rs := s.wildcards[domain] + if isV4 { + rs.A = append(rs.A, ip) + } else { + rs.AAAA = append(rs.AAAA, ip) + } + return nil + } + + path := domainToPath(domain) + node := s.root + for _, label := range path { + if node.children[label] == nil { + node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)} + } + node = node.children[label] + } + if node.data == nil { + node.data = &recordSet{} + } + if isV4 { + node.data.A = append(node.data.A, ip) + } else { + node.data.AAAA = append(node.data.AAAA, ip) + } + s.ptrRecords[ip.String()] = domain return nil } @@ -112,89 +145,74 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip == nil { - // Remove all records for this domain - if isWildcard { - delete(s.aWildcards, domain) - delete(s.aaaaWildcards, domain) + if isWildcard { + if ip == nil { + delete(s.wildcards, domain) + return + } + rs := s.wildcards[domain] + if rs == nil { + return + } + if ip.To4() != nil { + rs.A = removeIP(rs.A, ip) } else { - // For non-wildcard domains, remove PTR records for all IPs - if ips, ok := s.aRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - if ips, ok := s.aaaaRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - delete(s.aRecords, domain) - delete(s.aaaaRecords, domain) + rs.AAAA = removeIP(rs.AAAA, ip) + } + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(s.wildcards, domain) } return } + // Exact domain: find trie node + path := domainToPath(domain) + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + return + } + } + if node.data == nil { + return + } + + if ip == nil { + for _, ipAddr := range node.data.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range node.data.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + node.data = nil + return + } + if ip.To4() != nil { - // Remove specific IPv4 address - if isWildcard { - if ips, ok := s.aWildcards[domain]; ok { - s.aWildcards[domain] = removeIP(ips, ip) - if len(s.aWildcards[domain]) == 0 { - delete(s.aWildcards, domain) - } - } - } else { - if ips, ok := s.aRecords[domain]; ok { - s.aRecords[domain] = removeIP(ips, ip) - if len(s.aRecords[domain]) == 0 { - delete(s.aRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } - } + node.data.A = removeIP(node.data.A, ip) + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } - } else if ip.To16() != nil { - // Remove specific IPv6 address - if isWildcard { - if ips, ok := s.aaaaWildcards[domain]; ok { - s.aaaaWildcards[domain] = removeIP(ips, ip) - if len(s.aaaaWildcards[domain]) == 0 { - delete(s.aaaaWildcards, domain) - } - } - } else { - if ips, ok := s.aaaaRecords[domain]; ok { - s.aaaaRecords[domain] = removeIP(ips, ip) - if len(s.aaaaRecords[domain]) == 0 { - delete(s.aaaaRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } - } + } else { + node.data.AAAA = removeIP(node.data.AAAA, ip) + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } } + if len(node.data.A) == 0 && len(node.data.AAAA) == 0 { + node.data = nil + } } // RemovePTRRecord removes a PTR record for an IP address @@ -206,60 +224,54 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { } // GetRecords returns all IP addresses for a domain and record type -// First checks for exact matches, then checks wildcard patterns +// First checks for exact match in the trie, then wildcard patterns func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) + path := domainToPath(domain) - var records []net.IP - switch recordType { - case RecordTypeA: - // Check exact match first - if ips, ok := s.aRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records + // Exact match: walk trie + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + break } - // Check wildcard patterns - for pattern, ips := range s.aWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } + } + if node != nil && node.data != nil { + var ips []net.IP + if recordType == RecordTypeA { + ips = node.data.A + } else { + ips = node.data.AAAA } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result - } - - case RecordTypeAAAA: - // Check exact match first - if ips, ok := s.aaaaRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records - } - // Check wildcard patterns - for pattern, ips := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } - } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result + if len(ips) > 0 { + out := make([]net.IP, len(ips)) + copy(out, ips) + return out } } - return records + // Wildcard match + var records []net.IP + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA { + records = append(records, rs.A...) + } else { + records = append(records, rs.AAAA...) + } + } + if len(records) == 0 { + return nil + } + out := make([]net.IP, len(records)) + copy(out, records) + return out } // GetPTRRecord returns the domain name for a PTR record query @@ -283,39 +295,41 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { } // HasRecord checks if a domain has any records of the specified type -// Checks both exact matches and wildcard patterns +// Checks both exact matches (trie) and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) + path := domainToPath(domain) - switch recordType { - case RecordTypeA: - // Check exact match - if _, ok := s.aRecords[domain]; ok { + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + break + } + } + if node != nil && node.data != nil { + if recordType == RecordTypeA && len(node.data.A) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aWildcards { - if matchWildcard(pattern, domain) { - return true - } - } - case RecordTypeAAAA: - // Check exact match - if _, ok := s.aaaaRecords[domain]; ok { + if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - return true - } - } } + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA && len(rs.A) > 0 { + return true + } + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { + return true + } + } return false } @@ -339,10 +353,8 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.aRecords = make(map[string][]net.IP) - s.aaaaRecords = make(map[string][]net.IP) - s.aWildcards = make(map[string][]net.IP) - s.aaaaWildcards = make(map[string][]net.IP) + s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} + s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } @@ -494,4 +506,4 @@ func IPToReverseDNS(ip net.IP) string { } return "" -} \ No newline at end of file +}