refactor(dns): trie + unified record set for DNSRecordStore

- Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map
- Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps
- Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged
This commit is contained in:
Laurence
2026-02-26 11:15:42 +00:00
committed by Owen Schwartz
parent 21b66fbb34
commit 5ca4825800

View File

@@ -18,24 +18,49 @@ const (
RecordTypePTR RecordType = RecordType(dns.TypePTR) RecordTypePTR RecordType = RecordType(dns.TypePTR)
) )
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries // recordSet holds A and AAAA records for a single domain or wildcard pattern
type recordSet struct {
A []net.IP
AAAA []net.IP
}
// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path)
type domainTrieNode struct {
children map[string]*domainTrieNode
data *recordSet
}
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns
// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA).
type DNSRecordStore struct { type DNSRecordStore struct {
mu sync.RWMutex mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses root *domainTrieNode // trie root for exact lookups
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses ptrRecords map[string]string // IP address string -> domain name
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses }
ptrRecords map[string]string // IP address string -> domain name
// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"])
func domainToPath(domain string) []string {
domain = strings.ToLower(dns.Fqdn(domain))
domain = strings.TrimSuffix(domain, ".")
if domain == "" {
return nil
}
labels := strings.Split(domain, ".")
path := make([]string, 0, len(labels))
for i := len(labels) - 1; i >= 0; i-- {
path = append(path, labels[i])
}
return path
} }
// NewDNSRecordStore creates a new DNS record store // NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore { func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{ return &DNSRecordStore{
aRecords: make(map[string][]net.IP), root: &domainTrieNode{children: make(map[string]*domainTrieNode)},
aaaaRecords: make(map[string][]net.IP), wildcards: make(map[string]*recordSet),
aWildcards: make(map[string][]net.IP), ptrRecords: make(map[string]string),
aaaaWildcards: make(map[string][]net.IP),
ptrRecords: make(map[string]string),
} }
} }
@@ -48,39 +73,47 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' { if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
if ip.To4() != nil { isV4 := ip.To4() != nil
// IPv4 address if !isV4 && ip.To16() == nil {
if isWildcard {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else if ip.To16() != nil {
// IPv6 address
if isWildcard {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
// Automatically add PTR record for non-wildcard domains
s.ptrRecords[ip.String()] = domain
}
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()} return &net.ParseError{Type: "IP address", Text: ip.String()}
} }
if isWildcard {
if s.wildcards[domain] == nil {
s.wildcards[domain] = &recordSet{}
}
rs := s.wildcards[domain]
if isV4 {
rs.A = append(rs.A, ip)
} else {
rs.AAAA = append(rs.AAAA, ip)
}
return nil
}
path := domainToPath(domain)
node := s.root
for _, label := range path {
if node.children[label] == nil {
node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)}
}
node = node.children[label]
}
if node.data == nil {
node.data = &recordSet{}
}
if isV4 {
node.data.A = append(node.data.A, ip)
} else {
node.data.AAAA = append(node.data.AAAA, ip)
}
s.ptrRecords[ip.String()] = domain
return nil return nil
} }
@@ -112,89 +145,74 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
// Ensure domain ends with a dot (FQDN format)
if len(domain) == 0 || domain[len(domain)-1] != '.' { if len(domain) == 0 || domain[len(domain)-1] != '.' {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
if ip == nil { if isWildcard {
// Remove all records for this domain if ip == nil {
if isWildcard { delete(s.wildcards, domain)
delete(s.aWildcards, domain) return
delete(s.aaaaWildcards, domain) }
rs := s.wildcards[domain]
if rs == nil {
return
}
if ip.To4() != nil {
rs.A = removeIP(rs.A, ip)
} else { } else {
// For non-wildcard domains, remove PTR records for all IPs rs.AAAA = removeIP(rs.AAAA, ip)
if ips, ok := s.aRecords[domain]; ok { }
for _, ipAddr := range ips { if len(rs.A) == 0 && len(rs.AAAA) == 0 {
// Only remove PTR if it points to this domain delete(s.wildcards, domain)
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
if ips, ok := s.aaaaRecords[domain]; ok {
for _, ipAddr := range ips {
// Only remove PTR if it points to this domain
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
}
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
} }
return return
} }
// Exact domain: find trie node
path := domainToPath(domain)
node := s.root
for _, label := range path {
node = node.children[label]
if node == nil {
return
}
}
if node.data == nil {
return
}
if ip == nil {
for _, ipAddr := range node.data.A {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
for _, ipAddr := range node.data.AAAA {
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ipAddr.String())
}
}
node.data = nil
return
}
if ip.To4() != nil { if ip.To4() != nil {
// Remove specific IPv4 address node.data.A = removeIP(node.data.A, ip)
if isWildcard { if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
if ips, ok := s.aWildcards[domain]; ok { delete(s.ptrRecords, ip.String())
s.aWildcards[domain] = removeIP(ips, ip)
if len(s.aWildcards[domain]) == 0 {
delete(s.aWildcards, domain)
}
}
} else {
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
} }
} else if ip.To16() != nil { } else {
// Remove specific IPv6 address node.data.AAAA = removeIP(node.data.AAAA, ip)
if isWildcard { if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
if ips, ok := s.aaaaWildcards[domain]; ok { delete(s.ptrRecords, ip.String())
s.aaaaWildcards[domain] = removeIP(ips, ip)
if len(s.aaaaWildcards[domain]) == 0 {
delete(s.aaaaWildcards, domain)
}
}
} else {
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
// Automatically remove PTR record if it points to this domain
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
delete(s.ptrRecords, ip.String())
}
}
} }
} }
if len(node.data.A) == 0 && len(node.data.AAAA) == 0 {
node.data = nil
}
} }
// RemovePTRRecord removes a PTR record for an IP address // RemovePTRRecord removes a PTR record for an IP address
@@ -206,60 +224,54 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
} }
// GetRecords returns all IP addresses for a domain and record type // GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns // First checks for exact match in the trie, then wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
path := domainToPath(domain)
var records []net.IP // Exact match: walk trie
switch recordType { node := s.root
case RecordTypeA: for _, label := range path {
// Check exact match first node = node.children[label]
if ips, ok := s.aRecords[domain]; ok { if node == nil {
// Return a copy to prevent external modifications break
records = make([]net.IP, len(ips))
copy(records, ips)
return records
} }
// Check wildcard patterns }
for pattern, ips := range s.aWildcards { if node != nil && node.data != nil {
if matchWildcard(pattern, domain) { var ips []net.IP
records = append(records, ips...) if recordType == RecordTypeA {
} ips = node.data.A
} else {
ips = node.data.AAAA
} }
if len(records) > 0 { if len(ips) > 0 {
// Return a copy out := make([]net.IP, len(ips))
result := make([]net.IP, len(records)) copy(out, ips)
copy(result, records) return out
return result
}
case RecordTypeAAAA:
// Check exact match first
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
} }
} }
return records // Wildcard match
var records []net.IP
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
if recordType == RecordTypeA {
records = append(records, rs.A...)
} else {
records = append(records, rs.AAAA...)
}
}
if len(records) == 0 {
return nil
}
out := make([]net.IP, len(records))
copy(out, records)
return out
} }
// GetPTRRecord returns the domain name for a PTR record query // GetPTRRecord returns the domain name for a PTR record query
@@ -283,39 +295,41 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
} }
// HasRecord checks if a domain has any records of the specified type // HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns // Checks both exact matches (trie) and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN
domain = strings.ToLower(dns.Fqdn(domain)) domain = strings.ToLower(dns.Fqdn(domain))
path := domainToPath(domain)
switch recordType { node := s.root
case RecordTypeA: for _, label := range path {
// Check exact match node = node.children[label]
if _, ok := s.aRecords[domain]; ok { if node == nil {
break
}
}
if node != nil && node.data != nil {
if recordType == RecordTypeA && len(node.data.A) > 0 {
return true return true
} }
// Check wildcard patterns if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 {
for pattern := range s.aWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
case RecordTypeAAAA:
// Check exact match
if _, ok := s.aaaaRecords[domain]; ok {
return true return true
} }
// Check wildcard patterns
for pattern := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
} }
for pattern, rs := range s.wildcards {
if !matchWildcard(pattern, domain) {
continue
}
if recordType == RecordTypeA && len(rs.A) > 0 {
return true
}
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
return true
}
}
return false return false
} }
@@ -339,10 +353,8 @@ func (s *DNSRecordStore) Clear() {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.aRecords = make(map[string][]net.IP) s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)}
s.aaaaRecords = make(map[string][]net.IP) s.wildcards = make(map[string]*recordSet)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
s.ptrRecords = make(map[string]string) s.ptrRecords = make(map[string]string)
} }
@@ -494,4 +506,4 @@ func IPToReverseDNS(ip net.IP) string {
} }
return "" return ""
} }