mirror of
https://github.com/fosrl/olm.git
synced 2026-03-07 11:16:40 +00:00
refactor(dns): trie + unified record set for DNSRecordStore
- Replace four maps (aRecords, aaaaRecords, aWildcards, aaaaWildcards) with a label trie for exact lookups and a single wildcards map - Store one recordSet (A + AAAA) per domain/pattern instead of separate A and AAAA maps - Exact lookups O(labels); PTR unchanged (map); API and behaviour unchanged
This commit is contained in:
@@ -18,24 +18,49 @@ const (
|
|||||||
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
||||||
)
|
)
|
||||||
|
|
||||||
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries
|
// recordSet holds A and AAAA records for a single domain or wildcard pattern
|
||||||
|
type recordSet struct {
|
||||||
|
A []net.IP
|
||||||
|
AAAA []net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
// domainTrieNode is a node in the trie for exact domain lookups (no wildcards in path)
|
||||||
|
type domainTrieNode struct {
|
||||||
|
children map[string]*domainTrieNode
|
||||||
|
data *recordSet
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
|
||||||
|
// Exact domains are stored in a trie for O(label count) lookup; wildcard patterns
|
||||||
|
// are in a separate map. Each domain/pattern has a single recordSet (A + AAAA).
|
||||||
type DNSRecordStore struct {
|
type DNSRecordStore struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
|
root *domainTrieNode // trie root for exact lookups
|
||||||
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
|
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
|
||||||
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
|
ptrRecords map[string]string // IP address string -> domain name
|
||||||
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
|
}
|
||||||
ptrRecords map[string]string // IP address string -> domain name
|
|
||||||
|
// domainToPath converts a FQDN to a trie path (reversed labels, e.g. "host.internal." -> ["internal", "host"])
|
||||||
|
func domainToPath(domain string) []string {
|
||||||
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
domain = strings.TrimSuffix(domain, ".")
|
||||||
|
if domain == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
labels := strings.Split(domain, ".")
|
||||||
|
path := make([]string, 0, len(labels))
|
||||||
|
for i := len(labels) - 1; i >= 0; i-- {
|
||||||
|
path = append(path, labels[i])
|
||||||
|
}
|
||||||
|
return path
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDNSRecordStore creates a new DNS record store
|
// NewDNSRecordStore creates a new DNS record store
|
||||||
func NewDNSRecordStore() *DNSRecordStore {
|
func NewDNSRecordStore() *DNSRecordStore {
|
||||||
return &DNSRecordStore{
|
return &DNSRecordStore{
|
||||||
aRecords: make(map[string][]net.IP),
|
root: &domainTrieNode{children: make(map[string]*domainTrieNode)},
|
||||||
aaaaRecords: make(map[string][]net.IP),
|
wildcards: make(map[string]*recordSet),
|
||||||
aWildcards: make(map[string][]net.IP),
|
ptrRecords: make(map[string]string),
|
||||||
aaaaWildcards: make(map[string][]net.IP),
|
|
||||||
ptrRecords: make(map[string]string),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,39 +73,47 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Ensure domain ends with a dot (FQDN format)
|
|
||||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
// Check if domain contains wildcards
|
|
||||||
isWildcard := strings.ContainsAny(domain, "*?")
|
isWildcard := strings.ContainsAny(domain, "*?")
|
||||||
|
|
||||||
if ip.To4() != nil {
|
isV4 := ip.To4() != nil
|
||||||
// IPv4 address
|
if !isV4 && ip.To16() == nil {
|
||||||
if isWildcard {
|
|
||||||
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
|
|
||||||
} else {
|
|
||||||
s.aRecords[domain] = append(s.aRecords[domain], ip)
|
|
||||||
// Automatically add PTR record for non-wildcard domains
|
|
||||||
s.ptrRecords[ip.String()] = domain
|
|
||||||
}
|
|
||||||
} else if ip.To16() != nil {
|
|
||||||
// IPv6 address
|
|
||||||
if isWildcard {
|
|
||||||
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
|
|
||||||
} else {
|
|
||||||
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
|
|
||||||
// Automatically add PTR record for non-wildcard domains
|
|
||||||
s.ptrRecords[ip.String()] = domain
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if isWildcard {
|
||||||
|
if s.wildcards[domain] == nil {
|
||||||
|
s.wildcards[domain] = &recordSet{}
|
||||||
|
}
|
||||||
|
rs := s.wildcards[domain]
|
||||||
|
if isV4 {
|
||||||
|
rs.A = append(rs.A, ip)
|
||||||
|
} else {
|
||||||
|
rs.AAAA = append(rs.AAAA, ip)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
path := domainToPath(domain)
|
||||||
|
node := s.root
|
||||||
|
for _, label := range path {
|
||||||
|
if node.children[label] == nil {
|
||||||
|
node.children[label] = &domainTrieNode{children: make(map[string]*domainTrieNode)}
|
||||||
|
}
|
||||||
|
node = node.children[label]
|
||||||
|
}
|
||||||
|
if node.data == nil {
|
||||||
|
node.data = &recordSet{}
|
||||||
|
}
|
||||||
|
if isV4 {
|
||||||
|
node.data.A = append(node.data.A, ip)
|
||||||
|
} else {
|
||||||
|
node.data.AAAA = append(node.data.AAAA, ip)
|
||||||
|
}
|
||||||
|
s.ptrRecords[ip.String()] = domain
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,89 +145,74 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Ensure domain ends with a dot (FQDN format)
|
|
||||||
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
||||||
domain = domain + "."
|
domain = domain + "."
|
||||||
}
|
}
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
|
||||||
// Check if domain contains wildcards
|
|
||||||
isWildcard := strings.ContainsAny(domain, "*?")
|
isWildcard := strings.ContainsAny(domain, "*?")
|
||||||
|
|
||||||
if ip == nil {
|
if isWildcard {
|
||||||
// Remove all records for this domain
|
if ip == nil {
|
||||||
if isWildcard {
|
delete(s.wildcards, domain)
|
||||||
delete(s.aWildcards, domain)
|
return
|
||||||
delete(s.aaaaWildcards, domain)
|
}
|
||||||
|
rs := s.wildcards[domain]
|
||||||
|
if rs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ip.To4() != nil {
|
||||||
|
rs.A = removeIP(rs.A, ip)
|
||||||
} else {
|
} else {
|
||||||
// For non-wildcard domains, remove PTR records for all IPs
|
rs.AAAA = removeIP(rs.AAAA, ip)
|
||||||
if ips, ok := s.aRecords[domain]; ok {
|
}
|
||||||
for _, ipAddr := range ips {
|
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
|
||||||
// Only remove PTR if it points to this domain
|
delete(s.wildcards, domain)
|
||||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ipAddr.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
for _, ipAddr := range ips {
|
|
||||||
// Only remove PTR if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ipAddr.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(s.aRecords, domain)
|
|
||||||
delete(s.aaaaRecords, domain)
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exact domain: find trie node
|
||||||
|
path := domainToPath(domain)
|
||||||
|
node := s.root
|
||||||
|
for _, label := range path {
|
||||||
|
node = node.children[label]
|
||||||
|
if node == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if node.data == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip == nil {
|
||||||
|
for _, ipAddr := range node.data.A {
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ipAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, ipAddr := range node.data.AAAA {
|
||||||
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
||||||
|
delete(s.ptrRecords, ipAddr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
node.data = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.To4() != nil {
|
||||||
// Remove specific IPv4 address
|
node.data.A = removeIP(node.data.A, ip)
|
||||||
if isWildcard {
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||||
if ips, ok := s.aWildcards[domain]; ok {
|
delete(s.ptrRecords, ip.String())
|
||||||
s.aWildcards[domain] = removeIP(ips, ip)
|
|
||||||
if len(s.aWildcards[domain]) == 0 {
|
|
||||||
delete(s.aWildcards, domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ips, ok := s.aRecords[domain]; ok {
|
|
||||||
s.aRecords[domain] = removeIP(ips, ip)
|
|
||||||
if len(s.aRecords[domain]) == 0 {
|
|
||||||
delete(s.aRecords, domain)
|
|
||||||
}
|
|
||||||
// Automatically remove PTR record if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else if ip.To16() != nil {
|
} else {
|
||||||
// Remove specific IPv6 address
|
node.data.AAAA = removeIP(node.data.AAAA, ip)
|
||||||
if isWildcard {
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
||||||
if ips, ok := s.aaaaWildcards[domain]; ok {
|
delete(s.ptrRecords, ip.String())
|
||||||
s.aaaaWildcards[domain] = removeIP(ips, ip)
|
|
||||||
if len(s.aaaaWildcards[domain]) == 0 {
|
|
||||||
delete(s.aaaaWildcards, domain)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
s.aaaaRecords[domain] = removeIP(ips, ip)
|
|
||||||
if len(s.aaaaRecords[domain]) == 0 {
|
|
||||||
delete(s.aaaaRecords, domain)
|
|
||||||
}
|
|
||||||
// Automatically remove PTR record if it points to this domain
|
|
||||||
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
|
||||||
delete(s.ptrRecords, ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(node.data.A) == 0 && len(node.data.AAAA) == 0 {
|
||||||
|
node.data = nil
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePTRRecord removes a PTR record for an IP address
|
// RemovePTRRecord removes a PTR record for an IP address
|
||||||
@@ -206,60 +224,54 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRecords returns all IP addresses for a domain and record type
|
// GetRecords returns all IP addresses for a domain and record type
|
||||||
// First checks for exact matches, then checks wildcard patterns
|
// First checks for exact match in the trie, then wildcard patterns
|
||||||
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
path := domainToPath(domain)
|
||||||
|
|
||||||
var records []net.IP
|
// Exact match: walk trie
|
||||||
switch recordType {
|
node := s.root
|
||||||
case RecordTypeA:
|
for _, label := range path {
|
||||||
// Check exact match first
|
node = node.children[label]
|
||||||
if ips, ok := s.aRecords[domain]; ok {
|
if node == nil {
|
||||||
// Return a copy to prevent external modifications
|
break
|
||||||
records = make([]net.IP, len(ips))
|
|
||||||
copy(records, ips)
|
|
||||||
return records
|
|
||||||
}
|
}
|
||||||
// Check wildcard patterns
|
}
|
||||||
for pattern, ips := range s.aWildcards {
|
if node != nil && node.data != nil {
|
||||||
if matchWildcard(pattern, domain) {
|
var ips []net.IP
|
||||||
records = append(records, ips...)
|
if recordType == RecordTypeA {
|
||||||
}
|
ips = node.data.A
|
||||||
|
} else {
|
||||||
|
ips = node.data.AAAA
|
||||||
}
|
}
|
||||||
if len(records) > 0 {
|
if len(ips) > 0 {
|
||||||
// Return a copy
|
out := make([]net.IP, len(ips))
|
||||||
result := make([]net.IP, len(records))
|
copy(out, ips)
|
||||||
copy(result, records)
|
return out
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
case RecordTypeAAAA:
|
|
||||||
// Check exact match first
|
|
||||||
if ips, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
// Return a copy to prevent external modifications
|
|
||||||
records = make([]net.IP, len(ips))
|
|
||||||
copy(records, ips)
|
|
||||||
return records
|
|
||||||
}
|
|
||||||
// Check wildcard patterns
|
|
||||||
for pattern, ips := range s.aaaaWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
records = append(records, ips...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(records) > 0 {
|
|
||||||
// Return a copy
|
|
||||||
result := make([]net.IP, len(records))
|
|
||||||
copy(result, records)
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return records
|
// Wildcard match
|
||||||
|
var records []net.IP
|
||||||
|
for pattern, rs := range s.wildcards {
|
||||||
|
if !matchWildcard(pattern, domain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if recordType == RecordTypeA {
|
||||||
|
records = append(records, rs.A...)
|
||||||
|
} else {
|
||||||
|
records = append(records, rs.AAAA...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(records) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]net.IP, len(records))
|
||||||
|
copy(out, records)
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPTRRecord returns the domain name for a PTR record query
|
// GetPTRRecord returns the domain name for a PTR record query
|
||||||
@@ -283,39 +295,41 @@ func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// HasRecord checks if a domain has any records of the specified type
|
// HasRecord checks if a domain has any records of the specified type
|
||||||
// Checks both exact matches and wildcard patterns
|
// Checks both exact matches (trie) and wildcard patterns
|
||||||
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
// Normalize domain to lowercase FQDN
|
|
||||||
domain = strings.ToLower(dns.Fqdn(domain))
|
domain = strings.ToLower(dns.Fqdn(domain))
|
||||||
|
path := domainToPath(domain)
|
||||||
|
|
||||||
switch recordType {
|
node := s.root
|
||||||
case RecordTypeA:
|
for _, label := range path {
|
||||||
// Check exact match
|
node = node.children[label]
|
||||||
if _, ok := s.aRecords[domain]; ok {
|
if node == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if node != nil && node.data != nil {
|
||||||
|
if recordType == RecordTypeA && len(node.data.A) > 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Check wildcard patterns
|
if recordType == RecordTypeAAAA && len(node.data.AAAA) > 0 {
|
||||||
for pattern := range s.aWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case RecordTypeAAAA:
|
|
||||||
// Check exact match
|
|
||||||
if _, ok := s.aaaaRecords[domain]; ok {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// Check wildcard patterns
|
|
||||||
for pattern := range s.aaaaWildcards {
|
|
||||||
if matchWildcard(pattern, domain) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for pattern, rs := range s.wildcards {
|
||||||
|
if !matchWildcard(pattern, domain) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if recordType == RecordTypeA && len(rs.A) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,10 +353,8 @@ func (s *DNSRecordStore) Clear() {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
s.aRecords = make(map[string][]net.IP)
|
s.root = &domainTrieNode{children: make(map[string]*domainTrieNode)}
|
||||||
s.aaaaRecords = make(map[string][]net.IP)
|
s.wildcards = make(map[string]*recordSet)
|
||||||
s.aWildcards = make(map[string][]net.IP)
|
|
||||||
s.aaaaWildcards = make(map[string][]net.IP)
|
|
||||||
s.ptrRecords = make(map[string]string)
|
s.ptrRecords = make(map[string]string)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -494,4 +506,4 @@ func IPToReverseDNS(ip net.IP) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user