From 9ae49e36d5c341266d6eff74d948af99faabbd95 Mon Sep 17 00:00:00 2001 From: Laurence Date: Sat, 28 Feb 2026 10:03:09 +0000 Subject: [PATCH] refactor(dns): simplify DNSRecordStore from trie to map Replace trie-based domain lookup with simple map for O(1) lookups. Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses. --- dns/dns_proxy.go | 14 +-- dns/dns_records.go | 199 +++++++++++++++------------------------- dns/dns_records_test.go | 60 ++++++++---- 3 files changed, 125 insertions(+), 148 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 986e847..27770e4 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -447,19 +447,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns return nil } - ips := p.recordStore.GetRecords(question.Name, recordType) - if len(ips) == 0 { + ips, exists := p.recordStore.GetRecords(question.Name, recordType) + if !exists { + // Domain not found in local records, forward to upstream return nil } logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) - // Create response message + // Create response message (NODATA if no records, otherwise with answers) response := new(dns.Msg) response.SetReply(query) response.Authoritative = true - // Add answer records + // Add answer records (loop is a no-op if ips is empty) for _, ip := range ips { var rr dns.RR if question.Qtype == dns.TypeA { @@ -730,8 +731,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) { p.recordStore.RemoveRecord(domain, ip) } -// GetDNSRecords returns all IP addresses for a domain and record type -func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { +// GetDNSRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists. +func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) { return p.recordStore.GetRecords(domain, recordType) } diff --git a/dns/dns_records.go b/dns/dns_records.go index 5c62043..10bb7f3 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -24,41 +24,19 @@ type recordSet struct { AAAA []net.IP } -// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path) -type domainTrieNode struct { - children map[string]*domainTrieNode - data *recordSet -} - // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. -// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns -// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA). +// Exact domains are stored in a map; wildcard patterns are in a separate map. type DNSRecordStore struct { mu sync.RWMutex - root *domainTrieNode // trie root for exact lookups + exact map[string]*recordSet // normalized FQDN -> A/AAAA records wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records ptrRecords map[string]string // IP address string -> domain name } -// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"]) -func domainToPath(domain string) []string { - domain = strings.ToLower(dns.Fqdn(domain)) - domain = strings.TrimSuffix(domain, ".") - if domain == "" { - return nil - } - labels := strings.Split(domain, ".") - path := make([]string, 0, len(labels)) - for i := len(labels) - 1; i >= 0; i-- { - path = append(path, labels[i]) - } - return path -} - // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, + exact: make(map[string]*recordSet), wildcards: make(map[string]*recordSet), ptrRecords: make(map[string]string), } @@ -84,36 +62,26 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { return &net.ParseError{Type: "IP address", Text: ip.String()} } + // Choose the appropriate map based on whether this is a wildcard + m := s.exact if isWildcard { - if s.wildcards[domain] == nil { - s.wildcards[domain] = &recordSet{} - } - rs := s.wildcards[domain] - if isV4 { - rs.A = append(rs.A, ip) - } else { - rs.AAAA = append(rs.AAAA, ip) - } - return nil + m = s.wildcards } - path := domainToPath(domain) - node := s.root - for _, label := range path { - if node.children[label] == nil { - node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)} - } - node = node.children[label] - } - if node.data == nil { - node.data = &recordSet{} + if m[domain] == nil { + m[domain] = &recordSet{} } + rs := m[domain] if isV4 { - node.data.A = append(node.data.A, ip) + rs.A = append(rs.A, ip) } else { - node.data.AAAA = append(node.data.AAAA, ip) + rs.AAAA = append(rs.AAAA, ip) + } + + // Add PTR record for non-wildcard domains + if !isWildcard { + s.ptrRecords[ip.String()] = domain } - s.ptrRecords[ip.String()] = domain return nil } @@ -151,67 +119,55 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { domain = strings.ToLower(dns.Fqdn(domain)) isWildcard := strings.ContainsAny(domain, "*?") + // Choose the appropriate map + m := s.exact if isWildcard { - if ip == nil { - delete(s.wildcards, domain) - return - } - rs := s.wildcards[domain] - if rs == nil { - return - } - if ip.To4() != nil { - rs.A = removeIP(rs.A, ip) - } else { - rs.AAAA = removeIP(rs.AAAA, ip) - } - if len(rs.A) == 0 && len(rs.AAAA) == 0 { - delete(s.wildcards, domain) - } - return + m = s.wildcards } - // Exact domain: find trie node - path := domainToPath(domain) - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - return - } - } - if node.data == nil { + rs := m[domain] + if rs == nil { return } if ip == nil { - for _, ipAddr := range node.data.A { - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) + // Remove all records for this domain + if !isWildcard { + for _, ipAddr := range rs.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range rs.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } } } - for _, ipAddr := range node.data.AAAA { - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - node.data = nil + delete(m, domain) return } + // Remove specific IP if ip.To4() != nil { - node.data.A = removeIP(node.data.A, ip) - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) + rs.A = removeIP(rs.A, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } else { - node.data.AAAA = removeIP(node.data.AAAA, ip) - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) + rs.AAAA = removeIP(rs.AAAA, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } - if len(node.data.A) == 0 && len(node.data.AAAA) == 0 { - node.data = nil + + // Clean up empty record sets + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(m, domain) } } @@ -223,55 +179,56 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { delete(s.ptrRecords, ip.String()) } -// GetRecords returns all IP addresses for a domain and record type -// First checks for exact match in the trie, then wildcard patterns -func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { +// GetRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists at all +// (true = domain exists, use NODATA if no records; false = NXDOMAIN). +func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) { s.mu.RLock() defer s.mu.RUnlock() domain = strings.ToLower(dns.Fqdn(domain)) - path := domainToPath(domain) - // Exact match: walk trie - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - break - } - } - if node != nil && node.data != nil { + // Check exact match first + if rs, exists := s.exact[domain]; exists { var ips []net.IP if recordType == RecordTypeA { - ips = node.data.A + ips = rs.A } else { - ips = node.data.AAAA + ips = rs.AAAA } if len(ips) > 0 { out := make([]net.IP, len(ips)) copy(out, ips) - return out + return out, true } + // Domain exists but no records of this type + return nil, true } - // Wildcard match + // Check wildcard matches var records []net.IP + matched := false for pattern, rs := range s.wildcards { if !matchWildcard(pattern, domain) { continue } + matched = true if recordType == RecordTypeA { records = append(records, rs.A...) } else { records = append(records, rs.AAAA...) } } + + if !matched { + return nil, false + } if len(records) == 0 { - return nil + return nil, true } out := make([]net.IP, len(records)) copy(out, records) - return out + return out, true } // GetPTRRecord returns the domain name for a PTR record query @@ -295,30 +252,24 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { } // HasRecord checks if a domain has any records of the specified type -// Checks both exact matches (trie) and wildcard patterns +// Checks both exact matches and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() domain = strings.ToLower(dns.Fqdn(domain)) - path := domainToPath(domain) - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - break - } - } - if node != nil && node.data != nil { - if recordType == RecordTypeA && len(node.data.A) > 0 { + // Check exact match + if rs, exists := s.exact[domain]; exists { + if recordType == RecordTypeA && len(rs.A) > 0 { return true } - if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { return true } } + // Check wildcard matches for pattern, rs := range s.wildcards { if !matchWildcard(pattern, domain) { continue @@ -353,7 +304,7 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} + s.exact = make(map[string]*recordSet) s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index eae9372..963dcc1 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -183,25 +183,34 @@ func TestDNSRecordStoreWildcard(t *testing.T) { } // Test exact match takes precedence - ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) } - if !ips[0].Equal(exactIP) { + if len(ips) > 0 && !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } // Test wildcard match - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) } - if !ips[0].Equal(wildcardIP) { + if len(ips) > 0 && !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } // Test non-match (base domain) - ips = store.GetRecords("autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected base domain to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) } @@ -218,7 +227,10 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test matching domain - ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected complex wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) } @@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test non-matching domain (missing prefix) - ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain without prefix to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } // Test non-matching domain (wrong ? position) - ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain with wrong ? match to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) } @@ -250,7 +268,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { } // Verify it exists - ips := store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist before removal") + } if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } @@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store.RemoveRecord("*.autoco.internal", nil) // Verify it's gone - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain to not exist after removal") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -290,19 +314,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { } // Test domain matching only the prod pattern and the broad pattern - ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } // Test domain matching only the dev pattern and the broad pattern - ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } // Test domain matching only the broad pattern - ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) } @@ -319,7 +343,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { } // Test wildcard match for IPv6 - ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) + ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) } @@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range testCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) } @@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range wildcardTestCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) } @@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test removal with different case store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) - ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { } // Verify A record is also gone - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 0 { t.Errorf("Expected A record to be removed, got %d records", len(ips)) }