refactor(dns): simplify DNSRecordStore from trie to map

Replace trie-based domain lookup with simple map for O(1) lookups.
  Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses.
This commit is contained in:
Laurence
2026-02-28 10:03:09 +00:00
committed by Owen Schwartz
parent 5ca4825800
commit 9ae49e36d5
3 changed files with 125 additions and 148 deletions

View File

@@ -447,19 +447,20 @@ func (p *DNSProxy) checkLocalRecords(query *dns.Msg, question dns.Question) *dns
return nil return nil
} }
ips := p.recordStore.GetRecords(question.Name, recordType) ips, exists := p.recordStore.GetRecords(question.Name, recordType)
if len(ips) == 0 { if !exists {
// Domain not found in local records, forward to upstream
return nil return nil
} }
logger.Debug("Found %d local record(s) for %s", len(ips), question.Name) logger.Debug("Found %d local record(s) for %s", len(ips), question.Name)
// Create response message // Create response message (NODATA if no records, otherwise with answers)
response := new(dns.Msg) response := new(dns.Msg)
response.SetReply(query) response.SetReply(query)
response.Authoritative = true response.Authoritative = true
// Add answer records // Add answer records (loop is a no-op if ips is empty)
for _, ip := range ips { for _, ip := range ips {
var rr dns.RR var rr dns.RR
if question.Qtype == dns.TypeA { if question.Qtype == dns.TypeA {
@@ -730,8 +731,9 @@ func (p *DNSProxy) RemoveDNSRecord(domain string, ip net.IP) {
p.recordStore.RemoveRecord(domain, ip) p.recordStore.RemoveRecord(domain, ip)
} }
// GetDNSRecords returns all IP addresses for a domain and record type // GetDNSRecords returns all IP addresses for a domain and record type.
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) []net.IP { // The second return value indicates whether the domain exists.
func (p *DNSProxy) GetDNSRecords(domain string, recordType RecordType) ([]net.IP, bool) {
return p.recordStore.GetRecords(domain, recordType) return p.recordStore.GetRecords(domain, recordType)
} }

View File

@@ -24,41 +24,19 @@ type recordSet struct {
AAAA []net.IP AAAA []net.IP
} }
// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path)
type domainTrieNode struct {
children map[string]*domainTrieNode
data *recordSet
}
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns // Exact domains are stored in a map; wildcard patterns are in a separate map.
// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA).
type DNSRecordStore struct { type DNSRecordStore struct {
mu sync.RWMutex mu sync.RWMutex
root *domainTrieNode // trie root for exact lookups exact map[string]*recordSet // normalized FQDN -> A/AAAA records
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
ptrRecords map[string]string // IP address string -> domain name ptrRecords map[string]string // IP address string -> domain name
} }
// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"])
func domainToPath(domain string) []string {
domain = strings.ToLower(dns.Fqdn(domain))
domain = strings.TrimSuffix(domain, ".")
if domain == "" {
return nil
}
labels := strings.Split(domain, ".")
path := make([]string, 0, len(labels))
for i := len(labels) - 1; i >= 0; i-- {
path = append(path, labels[i])
}
return path
}
// NewDNSRecordStore creates a new DNS record store // NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore { func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{ return &DNSRecordStore{
root: &domainTrieNode{children: make(map[string]*domainTrieNode)}, exact: make(map[string]*recordSet),
wildcards: make(map[string]*recordSet), wildcards: make(map[string]*recordSet),
ptrRecords: make(map[string]string), ptrRecords: make(map[string]string),
} }
@@ -84,36 +62,26 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
return &net.ParseError{Type: "IP address", Text: ip.String()} return &net.ParseError{Type: "IP address", Text: ip.String()}
} }
// Choose the appropriate map based on whether this is a wildcard
m := s.exact
if isWildcard { if isWildcard {
if s.wildcards[domain] == nil { m = s.wildcards
s.wildcards[domain] = &recordSet{}
}
rs := s.wildcards[domain]
if isV4 {
rs.A = append(rs.A, ip)
} else {
rs.AAAA = append(rs.AAAA, ip)
}
return nil
} }
path := domainToPath(domain) if m[domain] == nil {
node := s.root m[domain] = &recordSet{}
for _, label := range path {
if node.children[label] == nil {
node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)}
}
node = node.children[label]
}
if node.data == nil {
node.data = &recordSet{}
} }
rs := m[domain]
if isV4 { if isV4 {
node.data.A = append(node.data.A, ip) rs.A = append(rs.A, ip)
} else { } else {
node.data.AAAA = append(node.data.AAAA, ip) rs.AAAA = append(rs.AAAA, ip)
}
// Add PTR record for non-wildcard domains
if !isWildcard {
s.ptrRecords[ip.String()] = domain
} }
s.ptrRecords[ip.String()] = domain
return nil return nil
} }
@@ -151,67 +119,55 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
// Choose the appropriate map
m := s.exact
if isWildcard { if isWildcard {
if ip == nil { m = s.wildcards
delete(s.wildcards, domain)
return
}
rs := s.wildcards[domain]
if rs == nil {
return
}
if ip.To4() != nil {
rs.A = removeIP(rs.A, ip)
} else {
rs.AAAA = removeIP(rs.AAAA, ip)
}
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
delete(s.wildcards, domain)
}
return
} }
// Exact domain: find trie node rs := m[domain]
path := domainToPath(domain) if rs == nil {
node := s.root
for _, label := range path {
node = node.children[label]
if node == nil {
return
}
}
if node.data == nil {
return return
} }
if ip == nil { if ip == nil {
for _, ipAddr := range node.data.A { // Remove all records for this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain { if !isWildcard {
delete(s.ptrRecords, ipAddr.String()) for _, ipAddr := range rs.A {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
for _, ipAddr := range rs.AAAA {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
} }
} }
for _, ipAddr := range node.data.AAAA { delete(m, domain)
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
node.data = nil
return return
} }
// Remove specific IP
if ip.To4() != nil { if ip.To4() != nil {
node.data.A = removeIP(node.data.A, ip) rs.A = removeIP(rs.A, ip)
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { if !isWildcard {
delete(s.ptrRecords, ip.String()) if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} else { } else {
node.data.AAAA = removeIP(node.data.AAAA, ip) rs.AAAA = removeIP(rs.AAAA, ip)
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain { if !isWildcard {
delete(s.ptrRecords, ip.String()) if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
} }
} }
if len(node.data.A) == 0 && len(node.data.AAAA) == 0 {
node.data = nil // Clean up empty record sets
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
delete(m, domain)
} }
} }
@@ -223,55 +179,56 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
// GetRecords returns all IP addresses for a domain and record type // GetRecords returns all IP addresses for a domain and record type.
// First checks for exact match in the trie, then wildcard patterns // The second return value indicates whether the domain exists at all
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { // (true = domain exists, use NODATA if no records; false = NXDOMAIN).
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
path := domainToPath(domain)
// Exact match: walk trie // Check exact match first
node := s.root if rs, exists := s.exact[domain]; exists {
for _, label := range path {
node = node.children[label]
if node == nil {
break
}
}
if node != nil && node.data != nil {
var ips []net.IP var ips []net.IP
if recordType == RecordTypeA { if recordType == RecordTypeA {
ips = node.data.A ips = rs.A
} else { } else {
ips = node.data.AAAA ips = rs.AAAA
} }
if len(ips) > 0 { if len(ips) > 0 {
out := make([]net.IP, len(ips)) out := make([]net.IP, len(ips))
copy(out, ips) copy(out, ips)
return out return out, true
} }
// Domain exists but no records of this type
return nil, true
} }
// Wildcard match // Check wildcard matches
var records []net.IP var records []net.IP
matched := false
for pattern, rs := range s.wildcards { for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) { if !matchWildcard(pattern, domain) {
continue continue
} }
matched = true
if recordType == RecordTypeA { if recordType == RecordTypeA {
records = append(records, rs.A...) records = append(records, rs.A...)
} else { } else {
records = append(records, rs.AAAA...) records = append(records, rs.AAAA...)
} }
} }
if !matched {
return nil, false
}
if len(records) == 0 { if len(records) == 0 {
return nil return nil, true
} }
out := make([]net.IP, len(records)) out := make([]net.IP, len(records))
copy(out, records) copy(out, records)
return out return out, true
} }
// GetPTRRecord returns the domain name for a PTR record query // GetPTRRecord returns the domain name for a PTR record query
@@ -295,30 +252,24 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
} }
// HasRecord checks if a domain has any records of the specified type // HasRecord checks if a domain has any records of the specified type
// Checks both exact matches (trie) and wildcard patterns // Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
path := domainToPath(domain)
node := s.root // Check exact match
for _, label := range path { if rs, exists := s.exact[domain]; exists {
node = node.children[label] if recordType == RecordTypeA && len(rs.A) > 0 {
if node == nil {
break
}
}
if node != nil && node.data != nil {
if recordType == RecordTypeA && len(node.data.A) > 0 {
return true return true
} }
if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 { if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
return true return true
} }
} }
// Check wildcard matches
for pattern, rs := range s.wildcards { for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) { if !matchWildcard(pattern, domain) {
continue continue
@@ -353,7 +304,7 @@ func (s *DNSRecordStore) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)} s.exact = make(map[string]*recordSet)
s.wildcards = make(map[string]*recordSet) s.wildcards = make(map[string]*recordSet)
s.ptrRecords = make(map[string]string) s.ptrRecords = make(map[string]string)
} }

View File

@@ -183,25 +183,34 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
} }
// Test exact match takes precedence // Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) ips, exists := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
} }
if !ips[0].Equal(exactIP) { if len(ips) > 0 && !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
} }
// Test wildcard match // Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected wildcard match to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
} }
if !ips[0].Equal(wildcardIP) { if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
} }
// Test non-match (base domain) // Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected base domain to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
} }
@@ -218,7 +227,10 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
} }
// Test matching domain // Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) ips, exists := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected complex wildcard match to exist")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
} }
@@ -227,13 +239,19 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
} }
// Test non-matching domain (missing prefix) // Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain without prefix to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
} }
// Test non-matching domain (wrong ? position) // Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain with wrong ? match to not exist")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
} }
@@ -250,7 +268,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
} }
// Verify it exists // Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA) ips, exists := store.GetRecords("host.autoco.internal.", RecordTypeA)
if !exists {
t.Error("Expected domain to exist before removal")
}
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips)) t.Errorf("Expected 1 IP before removal, got %d", len(ips))
} }
@@ -259,7 +280,10 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store.RemoveRecord("*.autoco.internal", nil) store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone // Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA) ips, exists = store.GetRecords("host.autoco.internal.", RecordTypeA)
if exists {
t.Error("Expected domain to not exist after removal")
}
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
} }
@@ -290,19 +314,19 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
} }
// Test domain matching only the prod pattern and the broad pattern // Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) ips, _ := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
} }
// Test domain matching only the dev pattern and the broad pattern // Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) ips, _ = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 { if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
} }
// Test domain matching only the broad pattern // Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) ips, _ = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
} }
@@ -319,7 +343,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
} }
// Test wildcard match for IPv6 // Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) ips, _ := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
} }
@@ -368,7 +392,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
} }
for _, domain := range testCases { for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA) ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips)) t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
} }
@@ -392,7 +416,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
} }
for _, domain := range wildcardTestCases { for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA) ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 { if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips)) t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
} }
@@ -403,7 +427,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test removal with different case // Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil) store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA) ips, _ := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
} }
@@ -752,7 +776,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
} }
// Verify A record is also gone // Verify A record is also gone
ips := store.GetRecords(domain, RecordTypeA) ips, _ := store.GetRecords(domain, RecordTypeA)
if len(ips) != 0 { if len(ips) != 0 {
t.Errorf("Expected A record to be removed, got %d records", len(ips)) t.Errorf("Expected A record to be removed, got %d records", len(ips))
} }