From 5ca48258007eac6be7ee0777aa091d2731ff9a6e Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 26 Feb 2026 11:15:42 +0000 Subject: [PATCH 1/3] refactor(dns): trie + unified record set for DNSRecordStore - Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map - Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps - Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged --- dns/dns_records.go | 358 +++++++++++++++++++++++---------------------- 1 file changed, 185 insertions(+), 173 deletions(-) diff --git a/dns/dns_records.go b/dns/dns_records.go index 199b94b..5c62043 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -18,24 +18,49 @@ const ( RecordTypePTR RecordType = RecordType(dns.TypePTR) ) -// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries +// recordSet holds A and AAAA records for a single domain or wildcard pattern +type recordSet struct { + A []net.IP + AAAA []net.IP +} + +// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path) +type domainTrieNode struct { + children map[string]*domainTrieNode + data *recordSet +} + +// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. +// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns +// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA). type DNSRecordStore struct { - mu sync.RWMutex - aRecords map[string][]net.IP // domain -> list of IPv4 addresses - aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses - aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses - aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses - ptrRecords map[string]string // IP address string -> domain name + mu sync.RWMutex + root *domainTrieNode // trie root for exact lookups + wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records + ptrRecords map[string]string // IP address string -> domain name +} + +// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"]) +func domainToPath(domain string) []string { + domain = strings.ToLower(dns.Fqdn(domain)) + domain = strings.TrimSuffix(domain, ".") + if domain == "" { + return nil + } + labels := strings.Split(domain, ".") + path := make([]string, 0, len(labels)) + for i := len(labels) - 1; i >= 0; i-- { + path = append(path, labels[i]) + } + return path } // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - aRecords: make(map[string][]net.IP), - aaaaRecords: make(map[string][]net.IP), - aWildcards: make(map[string][]net.IP), - aaaaWildcards: make(map[string][]net.IP), - ptrRecords: make(map[string]string), + root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, + wildcards: make(map[string]*recordSet), + ptrRecords: make(map[string]string), } } @@ -48,39 +73,47 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip.To4() != nil { - // IPv4 address - if isWildcard { - s.aWildcards[domain] = append(s.aWildcards[domain], ip) - } else { - s.aRecords[domain] = append(s.aRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else if ip.To16() != nil { - // IPv6 address - if isWildcard { - s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) - } else { - s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) - // Automatically add PTR record for non-wildcard domains - s.ptrRecords[ip.String()] = domain - } - } else { + isV4 := ip.To4() != nil + if !isV4 && ip.To16() == nil { return &net.ParseError{Type: "IP address", Text: ip.String()} } + if isWildcard { + if s.wildcards[domain] == nil { + s.wildcards[domain] = &recordSet{} + } + rs := s.wildcards[domain] + if isV4 { + rs.A = append(rs.A, ip) + } else { + rs.AAAA = append(rs.AAAA, ip) + } + return nil + } + + path := domainToPath(domain) + node := s.root + for _, label := range path { + if node.children[label] == nil { + node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)} + } + node = node.children[label] + } + if node.data == nil { + node.data = &recordSet{} + } + if isV4 { + node.data.A = append(node.data.A, ip) + } else { + node.data.AAAA = append(node.data.AAAA, ip) + } + s.ptrRecords[ip.String()] = domain return nil } @@ -112,89 +145,74 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { s.mu.Lock() defer s.mu.Unlock() - // Ensure domain ends with a dot (FQDN format) if len(domain) == 0 || domain[len(domain)-1] != '.' { domain = domain + "." } - - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) - - // Check if domain contains wildcards isWildcard := strings.ContainsAny(domain, "*?") - if ip == nil { - // Remove all records for this domain - if isWildcard { - delete(s.aWildcards, domain) - delete(s.aaaaWildcards, domain) + if isWildcard { + if ip == nil { + delete(s.wildcards, domain) + return + } + rs := s.wildcards[domain] + if rs == nil { + return + } + if ip.To4() != nil { + rs.A = removeIP(rs.A, ip) } else { - // For non-wildcard domains, remove PTR records for all IPs - if ips, ok := s.aRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - if ips, ok := s.aaaaRecords[domain]; ok { - for _, ipAddr := range ips { - // Only remove PTR if it points to this domain - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - } - delete(s.aRecords, domain) - delete(s.aaaaRecords, domain) + rs.AAAA = removeIP(rs.AAAA, ip) + } + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(s.wildcards, domain) } return } + // Exact domain: find trie node + path := domainToPath(domain) + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + return + } + } + if node.data == nil { + return + } + + if ip == nil { + for _, ipAddr := range node.data.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range node.data.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + node.data = nil + return + } + if ip.To4() != nil { - // Remove specific IPv4 address - if isWildcard { - if ips, ok := s.aWildcards[domain]; ok { - s.aWildcards[domain] = removeIP(ips, ip) - if len(s.aWildcards[domain]) == 0 { - delete(s.aWildcards, domain) - } - } - } else { - if ips, ok := s.aRecords[domain]; ok { - s.aRecords[domain] = removeIP(ips, ip) - if len(s.aRecords[domain]) == 0 { - delete(s.aRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } - } + node.data.A = removeIP(node.data.A, ip) + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } - } else if ip.To16() != nil { - // Remove specific IPv6 address - if isWildcard { - if ips, ok := s.aaaaWildcards[domain]; ok { - s.aaaaWildcards[domain] = removeIP(ips, ip) - if len(s.aaaaWildcards[domain]) == 0 { - delete(s.aaaaWildcards, domain) - } - } - } else { - if ips, ok := s.aaaaRecords[domain]; ok { - s.aaaaRecords[domain] = removeIP(ips, ip) - if len(s.aaaaRecords[domain]) == 0 { - delete(s.aaaaRecords, domain) - } - // Automatically remove PTR record if it points to this domain - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) - } - } + } else { + node.data.AAAA = removeIP(node.data.AAAA, ip) + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) } } + if len(node.data.A) == 0 && len(node.data.AAAA) == 0 { + node.data = nil + } } // RemovePTRRecord removes a PTR record for an IP address @@ -206,60 +224,54 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { } // GetRecords returns all IP addresses for a domain and record type -// First checks for exact matches, then checks wildcard patterns +// First checks for exact match in the trie, then wildcard patterns func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) + path := domainToPath(domain) - var records []net.IP - switch recordType { - case RecordTypeA: - // Check exact match first - if ips, ok := s.aRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records + // Exact match: walk trie + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + break } - // Check wildcard patterns - for pattern, ips := range s.aWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } + } + if node != nil && node.data != nil { + var ips []net.IP + if recordType == RecordTypeA { + ips = node.data.A + } else { + ips = node.data.AAAA } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result - } - - case RecordTypeAAAA: - // Check exact match first - if ips, ok := s.aaaaRecords[domain]; ok { - // Return a copy to prevent external modifications - records = make([]net.IP, len(ips)) - copy(records, ips) - return records - } - // Check wildcard patterns - for pattern, ips := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - records = append(records, ips...) - } - } - if len(records) > 0 { - // Return a copy - result := make([]net.IP, len(records)) - copy(result, records) - return result + if len(ips) > 0 { + out := make([]net.IP, len(ips)) + copy(out, ips) + return out } } - return records + // Wildcard match + var records []net.IP + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA { + records = append(records, rs.A...) + } else { + records = append(records, rs.AAAA...) + } + } + if len(records) == 0 { + return nil + } + out := make([]net.IP, len(records)) + copy(out, records) + return out } // GetPTRRecord returns the domain name for a PTR record query @@ -283,39 +295,41 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { } // HasRecord checks if a domain has any records of the specified type -// Checks both exact matches and wildcard patterns +// Checks both exact matches (trie) and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() - // Normalize domain to lowercase FQDN domain = strings.ToLower(dns.Fqdn(domain)) + path := domainToPath(domain) - switch recordType { - case RecordTypeA: - // Check exact match - if _, ok := s.aRecords[domain]; ok { + node := s.root + for _, label := range path { + node = node.children[label] + if node == nil { + break + } + } + if node != nil && node.data != nil { + if recordType == RecordTypeA && len(node.data.A) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aWildcards { - if matchWildcard(pattern, domain) { - return true - } - } - case RecordTypeAAAA: - // Check exact match - if _, ok := s.aaaaRecords[domain]; ok { + if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { return true } - // Check wildcard patterns - for pattern := range s.aaaaWildcards { - if matchWildcard(pattern, domain) { - return true - } - } } + for pattern, rs := range s.wildcards { + if !matchWildcard(pattern, domain) { + continue + } + if recordType == RecordTypeA && len(rs.A) > 0 { + return true + } + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { + return true + } + } return false } @@ -339,10 +353,8 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.aRecords = make(map[string][]net.IP) - s.aaaaRecords = make(map[string][]net.IP) - s.aWildcards = make(map[string][]net.IP) - s.aaaaWildcards = make(map[string][]net.IP) + s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} + s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } @@ -494,4 +506,4 @@ func IPToReverseDNS(ip net.IP) string { } return "" -} \ No newline at end of file +} From 9ae49e36d5c341266d6eff74d948af99faabbd95 Mon Sep 17 00:00:00 2001 From: Laurence Date: Sat, 28 Feb 2026 10:03:09 +0000 Subject: [PATCH 2/3] refactor(dns): simplify DNSRecordStore from trie to map Replace trie-based domain lookup with simple map for O(1) lookups. Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses. --- dns/dns_proxy.go | 14 +-- dns/dns_records.go | 199 +++++++++++++++------------------------- dns/dns_records_test.go | 60 ++++++++---- 3 files changed, 125 insertions(+), 148 deletions(-) diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 986e847..27770e4 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -447,19 +447,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns return nil } - ips := p.recordStore.GetRecords(question.Name, recordType) - if len(ips) == 0 { + ips, exists := p.recordStore.GetRecords(question.Name, recordType) + if !exists { + // Domain not found in local records, forward to upstream return nil } logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) - // Create response message + // Create response message (NODATA if no records, otherwise with answers) response := new(dns.Msg) response.SetReply(query) response.Authoritative = true - // Add answer records + // Add answer records (loop is a no-op if ips is empty) for _, ip := range ips { var rr dns.RR if question.Qtype == dns.TypeA { @@ -730,8 +731,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) { p.recordStore.RemoveRecord(domain, ip) } -// GetDNSRecords returns all IP addresses for a domain and record type -func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { +// GetDNSRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists. +func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) { return p.recordStore.GetRecords(domain, recordType) } diff --git a/dns/dns_records.go b/dns/dns_records.go index 5c62043..10bb7f3 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -24,41 +24,19 @@ type recordSet struct { AAAA []net.IP } -// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path) -type domainTrieNode struct { - children map[string]*domainTrieNode - data *recordSet -} - // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. -// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns -// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA). +// Exact domains are stored in a map; wildcard patterns are in a separate map. type DNSRecordStore struct { mu sync.RWMutex - root *domainTrieNode // trie root for exact lookups + exact map[string]*recordSet // normalized FQDN -> A/AAAA records wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records ptrRecords map[string]string // IP address string -> domain name } -// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"]) -func domainToPath(domain string) []string { - domain = strings.ToLower(dns.Fqdn(domain)) - domain = strings.TrimSuffix(domain, ".") - if domain == "" { - return nil - } - labels := strings.Split(domain, ".") - path := make([]string, 0, len(labels)) - for i := len(labels) - 1; i >= 0; i-- { - path = append(path, labels[i]) - } - return path -} - // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, + exact: make(map[string]*recordSet), wildcards: make(map[string]*recordSet), ptrRecords: make(map[string]string), } @@ -84,36 +62,26 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { return &net.ParseError{Type: "IP address", Text: ip.String()} } + // Choose the appropriate map based on whether this is a wildcard + m := s.exact if isWildcard { - if s.wildcards[domain] == nil { - s.wildcards[domain] = &recordSet{} - } - rs := s.wildcards[domain] - if isV4 { - rs.A = append(rs.A, ip) - } else { - rs.AAAA = append(rs.AAAA, ip) - } - return nil + m = s.wildcards } - path := domainToPath(domain) - node := s.root - for _, label := range path { - if node.children[label] == nil { - node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)} - } - node = node.children[label] - } - if node.data == nil { - node.data = &recordSet{} + if m[domain] == nil { + m[domain] = &recordSet{} } + rs := m[domain] if isV4 { - node.data.A = append(node.data.A, ip) + rs.A = append(rs.A, ip) } else { - node.data.AAAA = append(node.data.AAAA, ip) + rs.AAAA = append(rs.AAAA, ip) + } + + // Add PTR record for non-wildcard domains + if !isWildcard { + s.ptrRecords[ip.String()] = domain } - s.ptrRecords[ip.String()] = domain return nil } @@ -151,67 +119,55 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { domain = strings.ToLower(dns.Fqdn(domain)) isWildcard := strings.ContainsAny(domain, "*?") + // Choose the appropriate map + m := s.exact if isWildcard { - if ip == nil { - delete(s.wildcards, domain) - return - } - rs := s.wildcards[domain] - if rs == nil { - return - } - if ip.To4() != nil { - rs.A = removeIP(rs.A, ip) - } else { - rs.AAAA = removeIP(rs.AAAA, ip) - } - if len(rs.A) == 0 && len(rs.AAAA) == 0 { - delete(s.wildcards, domain) - } - return + m = s.wildcards } - // Exact domain: find trie node - path := domainToPath(domain) - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - return - } - } - if node.data == nil { + rs := m[domain] + if rs == nil { return } if ip == nil { - for _, ipAddr := range node.data.A { - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) + // Remove all records for this domain + if !isWildcard { + for _, ipAddr := range rs.A { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } + } + for _, ipAddr := range rs.AAAA { + if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ipAddr.String()) + } } } - for _, ipAddr := range node.data.AAAA { - if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ipAddr.String()) - } - } - node.data = nil + delete(m, domain) return } + // Remove specific IP if ip.To4() != nil { - node.data.A = removeIP(node.data.A, ip) - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) + rs.A = removeIP(rs.A, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } else { - node.data.AAAA = removeIP(node.data.AAAA, ip) - if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { - delete(s.ptrRecords, ip.String()) + rs.AAAA = removeIP(rs.AAAA, ip) + if !isWildcard { + if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { + delete(s.ptrRecords, ip.String()) + } } } - if len(node.data.A) == 0 && len(node.data.AAAA) == 0 { - node.data = nil + + // Clean up empty record sets + if len(rs.A) == 0 && len(rs.AAAA) == 0 { + delete(m, domain) } } @@ -223,55 +179,56 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { delete(s.ptrRecords, ip.String()) } -// GetRecords returns all IP addresses for a domain and record type -// First checks for exact match in the trie, then wildcard patterns -func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { +// GetRecords returns all IP addresses for a domain and record type. +// The second return value indicates whether the domain exists at all +// (true = domain exists, use NODATA if no records; false = NXDOMAIN). +func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) { s.mu.RLock() defer s.mu.RUnlock() domain = strings.ToLower(dns.Fqdn(domain)) - path := domainToPath(domain) - // Exact match: walk trie - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - break - } - } - if node != nil && node.data != nil { + // Check exact match first + if rs, exists := s.exact[domain]; exists { var ips []net.IP if recordType == RecordTypeA { - ips = node.data.A + ips = rs.A } else { - ips = node.data.AAAA + ips = rs.AAAA } if len(ips) > 0 { out := make([]net.IP, len(ips)) copy(out, ips) - return out + return out, true } + // Domain exists but no records of this type + return nil, true } - // Wildcard match + // Check wildcard matches var records []net.IP + matched := false for pattern, rs := range s.wildcards { if !matchWildcard(pattern, domain) { continue } + matched = true if recordType == RecordTypeA { records = append(records, rs.A...) } else { records = append(records, rs.AAAA...) } } + + if !matched { + return nil, false + } if len(records) == 0 { - return nil + return nil, true } out := make([]net.IP, len(records)) copy(out, records) - return out + return out, true } // GetPTRRecord returns the domain name for a PTR record query @@ -295,30 +252,24 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) { } // HasRecord checks if a domain has any records of the specified type -// Checks both exact matches (trie) and wildcard patterns +// Checks both exact matches and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() domain = strings.ToLower(dns.Fqdn(domain)) - path := domainToPath(domain) - node := s.root - for _, label := range path { - node = node.children[label] - if node == nil { - break - } - } - if node != nil && node.data != nil { - if recordType == RecordTypeA && len(node.data.A) > 0 { + // Check exact match + if rs, exists := s.exact[domain]; exists { + if recordType == RecordTypeA && len(rs.A) > 0 { return true } - if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { + if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 { return true } } + // Check wildcard matches for pattern, rs := range s.wildcards { if !matchWildcard(pattern, domain) { continue @@ -353,7 +304,7 @@ func (s *DNSRecordStore) Clear() { s.mu.Lock() defer s.mu.Unlock() - s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} + s.exact = make(map[string]*recordSet) s.wildcards = make(map[string]*recordSet) s.ptrRecords = make(map[string]string) } diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index eae9372..963dcc1 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -183,25 +183,34 @@ func TestDNSRecordStoreWildcard(t *testing.T) { } // Test exact match takes precedence - ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) } - if !ips[0].Equal(exactIP) { + if len(ips) > 0 && !ips[0].Equal(exactIP) { t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) } // Test wildcard match - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) } - if !ips[0].Equal(wildcardIP) { + if len(ips) > 0 && !ips[0].Equal(wildcardIP) { t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) } // Test non-match (base domain) - ips = store.GetRecords("autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected base domain to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) } @@ -218,7 +227,10 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test matching domain - ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected complex wildcard match to exist") + } if len(ips) != 1 { t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) } @@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { } // Test non-matching domain (missing prefix) - ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain without prefix to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) } // Test non-matching domain (wrong ? position) - ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain with wrong ? match to not exist") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) } @@ -250,7 +268,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { } // Verify it exists - ips := store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA) + if !exists { + t.Error("Expected domain to exist before removal") + } if len(ips) != 1 { t.Errorf("Expected 1 IP before removal, got %d", len(ips)) } @@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { store.RemoveRecord("*.autoco.internal", nil) // Verify it's gone - ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA) + if exists { + t.Error("Expected domain to not exist after removal") + } if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -290,19 +314,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { } // Test domain matching only the prod pattern and the broad pattern - ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) } // Test domain matching only the dev pattern and the broad pattern - ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) if len(ips) != 2 { t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) } // Test domain matching only the broad pattern - ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) + ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) } @@ -319,7 +343,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { } // Test wildcard match for IPv6 - ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) + ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) if len(ips) != 1 { t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) } @@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range testCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) } @@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { } for _, domain := range wildcardTestCases { - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 1 { t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) } @@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test removal with different case store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) - ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA) + ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA) if len(ips) != 0 { t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) } @@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { } // Verify A record is also gone - ips := store.GetRecords(domain, RecordTypeA) + ips, _ := store.GetRecords(domain, RecordTypeA) if len(ips) != 0 { t.Errorf("Expected A record to be removed, got %d records", len(ips)) } From ae88766d85926e3643fc7ec0c6452b0270da8167 Mon Sep 17 00:00:00 2001 From: Laurence Date: Sat, 28 Feb 2026 10:22:37 +0000 Subject: [PATCH 3/3] test(dns): add dns test cases for nodata --- dns/dns_proxy_test.go | 178 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 dns/dns_proxy_test.go diff --git a/dns/dns_proxy_test.go b/dns/dns_proxy_test.go new file mode 100644 index 0000000..4a1d9f9 --- /dev/null +++ b/dns/dns_proxy_test.go @@ -0,0 +1,178 @@ +package dns + +import ( + "net" + "testing" + + "github.com/miekg/dns" +) + +func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add an A record for a domain (no AAAA record) + ip := net.ParseIP("10.0.0.1") + err := proxy.recordStore.AddRecord("myservice.internal", ip) + if err != nil { + t.Fatalf("Failed to add A record: %v", err) + } + + // Query AAAA for domain with only A record - should return NODATA + query := new(dns.Msg) + query.SetQuestion("myservice.internal.", dns.TypeAAAA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response, got nil (would forward to upstream)") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section for NODATA, got %d answers", len(response.Answer)) + } + if !response.Authoritative { + t.Error("Expected response to be authoritative") + } + + // Query A for same domain - should return the record + query = new(dns.Msg) + query.SetQuestion("myservice.internal.", dns.TypeA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with A record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } + aRecord, ok := response.Answer[0].(*dns.A) + if !ok { + t.Fatal("Expected A record in answer") + } + if !aRecord.A.Equal(ip.To4()) { + t.Errorf("Expected IP %v, got %v", ip.To4(), aRecord.A) + } +} + +func TestCheckLocalRecordsNODATAForA(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add an AAAA record for a domain (no A record) + ip := net.ParseIP("2001:db8::1") + err := proxy.recordStore.AddRecord("ipv6only.internal", ip) + if err != nil { + t.Fatalf("Failed to add AAAA record: %v", err) + } + + // Query A for domain with only AAAA record - should return NODATA + query := new(dns.Msg) + query.SetQuestion("ipv6only.internal.", dns.TypeA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response, got nil") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section, got %d answers", len(response.Answer)) + } + if !response.Authoritative { + t.Error("Expected response to be authoritative") + } + + // Query AAAA for same domain - should return the record + query = new(dns.Msg) + query.SetQuestion("ipv6only.internal.", dns.TypeAAAA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with AAAA record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } + aaaaRecord, ok := response.Answer[0].(*dns.AAAA) + if !ok { + t.Fatal("Expected AAAA record in answer") + } + if !aaaaRecord.AAAA.Equal(ip) { + t.Errorf("Expected IP %v, got %v", ip, aaaaRecord.AAAA) + } +} + +func TestCheckLocalRecordsNonExistentDomain(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add a record so the store isn't empty + err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) + if err != nil { + t.Fatalf("Failed to add record: %v", err) + } + + // Query A for non-existent domain - should return nil (forward to upstream) + query := new(dns.Msg) + query.SetQuestion("unknown.internal.", dns.TypeA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response != nil { + t.Error("Expected nil for non-existent domain, got response") + } + + // Query AAAA for non-existent domain - should also return nil + query = new(dns.Msg) + query.SetQuestion("unknown.internal.", dns.TypeAAAA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response != nil { + t.Error("Expected nil for non-existent domain, got response") + } +} + +func TestCheckLocalRecordsNODATAWildcard(t *testing.T) { + proxy := &DNSProxy{ + recordStore: NewDNSRecordStore(), + } + + // Add a wildcard A record (no AAAA) + ip := net.ParseIP("10.0.0.1") + err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) + if err != nil { + t.Fatalf("Failed to add wildcard A record: %v", err) + } + + // Query AAAA for wildcard-matched domain - should return NODATA + query := new(dns.Msg) + query.SetQuestion("host.wildcard.internal.", dns.TypeAAAA) + response := proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected NODATA response for wildcard match, got nil") + } + if response.Rcode != dns.RcodeSuccess { + t.Errorf("Expected Rcode NOERROR (0), got %d", response.Rcode) + } + if len(response.Answer) != 0 { + t.Errorf("Expected empty answer section, got %d answers", len(response.Answer)) + } + + // Query A for wildcard-matched domain - should return the record + query = new(dns.Msg) + query.SetQuestion("host.wildcard.internal.", dns.TypeA) + response = proxy.checkLocalRecords(query, query.Question[0]) + + if response == nil { + t.Fatal("Expected response with A record, got nil") + } + if len(response.Answer) != 1 { + t.Fatalf("Expected 1 answer, got %d", len(response.Answer)) + } +}