Files
olm/dns/dns_records.go
Laurence 9ae49e36d5 refactor(dns): simplify DNSRecordStore from trie to map
Replace trie-based domain lookup with simple map for O(1) lookups.
  Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses.
2026-03-06 16:08:01 -08:00

461 lines
12 KiB
Go

package dns
import (
"fmt"
"net"
"strings"
"sync"
"github.com/miekg/dns"
)
// RecordType identifies a DNS resource record type. Its values are the
// numeric type codes from github.com/miekg/dns.
type RecordType uint16

// Record types handled by DNSRecordStore.
const (
	RecordTypeA    = RecordType(dns.TypeA)
	RecordTypeAAAA = RecordType(dns.TypeAAAA)
	RecordTypePTR  = RecordType(dns.TypePTR)
)
// recordSet groups the addresses published for one exact domain or one
// wildcard pattern: addresses stored as A records in A, AAAA records in AAAA.
type recordSet struct {
	A, AAAA []net.IP
}
// DNSRecordStore is an in-memory source of A, AAAA and PTR answers.
// Every method takes mu before touching the maps, so a store may be shared
// between goroutines.
type DNSRecordStore struct {
	// mu guards all three maps.
	mu sync.RWMutex

	// exact maps a lowercased FQDN to its address records.
	exact map[string]*recordSet
	// wildcards maps a lowercased pattern containing '*' or '?' to its
	// address records; lookups scan it and test each pattern.
	wildcards map[string]*recordSet
	// ptrRecords maps the net.IP.String() form of an address to the domain
	// returned for reverse lookups.
	ptrRecords map[string]string
}
// NewDNSRecordStore returns an empty store whose maps are already allocated,
// so it can be written to immediately.
func NewDNSRecordStore() *DNSRecordStore {
	store := new(DNSRecordStore)
	store.exact = map[string]*recordSet{}
	store.wildcards = map[string]*recordSet{}
	store.ptrRecords = map[string]string{}
	return store
}
// AddRecord adds an A or AAAA record for domain.
//
// domain is lowercased and made fully qualified (e.g. "example.com.") before
// it is stored. It may contain '*' (zero or more characters) and '?' (exactly
// one character); such wildcard patterns are kept apart from exact names.
// Addresses for which ip.To4 succeeds (IPv4, including IPv4-mapped IPv6) are
// stored as A records; all other 16-byte addresses are stored as AAAA records.
// Adding an address the domain already has leaves its record set unchanged,
// so repeated calls never produce duplicate answers.
//
// For non-wildcard domains the PTR record for ip is (re)pointed at domain.
//
// It returns a *net.ParseError, without modifying the store, if ip is neither
// a valid IPv4 nor a valid IPv6 address.
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
	isV4 := ip.To4() != nil
	if !isV4 && ip.To16() == nil {
		return &net.ParseError{Type: "IP address", Text: ip.String()}
	}

	if !strings.HasSuffix(domain, ".") {
		domain += "."
	}
	domain = strings.ToLower(dns.Fqdn(domain))
	isWildcard := strings.ContainsAny(domain, "*?")

	s.mu.Lock()
	defer s.mu.Unlock()

	// Wildcard patterns live in their own map so exact-name lookups stay a
	// single map access.
	m := s.exact
	if isWildcard {
		m = s.wildcards
	}
	rs := m[domain]
	if rs == nil {
		rs = &recordSet{}
		m[domain] = rs
	}

	list := &rs.AAAA
	if isV4 {
		list = &rs.A
	}
	// RFC 2181 §5: identical records in one RRset are duplicates to be
	// suppressed, so only append addresses not already present.
	present := false
	for _, existing := range *list {
		if existing.Equal(ip) {
			present = true
			break
		}
	}
	if !present {
		*list = append(*list, ip)
	}

	// The PTR is updated even for a duplicate add, so the most recent
	// AddRecord call for an address decides its reverse name.
	if !isWildcard {
		s.ptrRecords[ip.String()] = domain
	}
	return nil
}
// AddPTRRecord stores a PTR record mapping ip to domain, replacing any PTR
// record already stored for that address.
//
// domain is lowercased and made fully qualified (e.g. "example.com.").
// It returns a *net.ParseError, and stores nothing, if ip is neither a valid
// IPv4 nor a valid IPv6 address.
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
	// Same validation as AddRecord. An invalid address stringifies to a key
	// such as "<nil>" or "?…" that GetPTRRecord can never look up, so storing
	// it would only leak an unreachable entry.
	if ip.To4() == nil && ip.To16() == nil {
		return &net.ParseError{Type: "IP address", Text: ip.String()}
	}

	if !strings.HasSuffix(domain, ".") {
		domain += "."
	}
	domain = strings.ToLower(dns.Fqdn(domain))

	s.mu.Lock()
	defer s.mu.Unlock()
	s.ptrRecords[ip.String()] = domain
	return nil
}
// RemoveRecord deletes address records for domain.
//
// domain is lowercased and made fully qualified before lookup; names that
// contain '*' or '?' address the wildcard patterns. If ip is nil, the whole
// record set for the domain is dropped. Otherwise only addresses equal to ip
// are removed, and the set is dropped once it has neither A nor AAAA entries.
// For non-wildcard domains, the PTR record of each removed address is deleted
// too, but only if it still points at this domain.
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !strings.HasSuffix(domain, ".") {
		domain += "."
	}
	domain = strings.ToLower(dns.Fqdn(domain))

	wildcard := strings.ContainsAny(domain, "*?")
	table := s.exact
	if wildcard {
		table = s.wildcards
	}

	set := table[domain]
	if set == nil {
		return
	}

	// dropPTR deletes addr's PTR record when it names this (non-wildcard)
	// domain; PTRs owned by other domains are left alone.
	dropPTR := func(addr net.IP) {
		if wildcard {
			return
		}
		key := addr.String()
		if target, ok := s.ptrRecords[key]; ok && target == domain {
			delete(s.ptrRecords, key)
		}
	}

	if ip == nil {
		for _, group := range [][]net.IP{set.A, set.AAAA} {
			for _, addr := range group {
				dropPTR(addr)
			}
		}
		delete(table, domain)
		return
	}

	if ip.To4() != nil {
		set.A = removeIP(set.A, ip)
	} else {
		set.AAAA = removeIP(set.AAAA, ip)
	}
	dropPTR(ip)

	if len(set.A)+len(set.AAAA) == 0 {
		delete(table, domain)
	}
}
// RemovePTRRecord deletes the PTR record stored for ip; it is a no-op when
// none exists.
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
	key := ip.String()

	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.ptrRecords, key)
}
// GetRecords returns the addresses of the given type published for domain.
//
// domain is lowercased and made fully qualified before lookup. An exact entry
// takes precedence: when one exists, wildcard patterns are not consulted.
// Otherwise the addresses of every matching wildcard pattern are merged, with
// duplicate addresses removed. Only RecordTypeA and RecordTypeAAAA carry
// addresses; any other type (e.g. RecordTypePTR) yields none.
//
// The second result reports whether the name exists at all: true with no
// addresses means NODATA, false means NXDOMAIN. The returned slice is freshly
// allocated, but its net.IP elements are not deep-copied.
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
	domain = strings.ToLower(dns.Fqdn(domain))

	// addrs picks the slice for recordType. Unknown types must not fall back
	// to AAAA, otherwise e.g. a PTR-typed query would receive IPv6 addresses.
	addrs := func(rs *recordSet) []net.IP {
		switch recordType {
		case RecordTypeA:
			return rs.A
		case RecordTypeAAAA:
			return rs.AAAA
		default:
			return nil
		}
	}
	containsIP := func(list []net.IP, ip net.IP) bool {
		for _, existing := range list {
			if existing.Equal(ip) {
				return true
			}
		}
		return false
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	if rs, ok := s.exact[domain]; ok {
		ips := addrs(rs)
		if len(ips) == 0 {
			return nil, true // name exists, but has no records of this type
		}
		out := make([]net.IP, len(ips))
		copy(out, ips)
		return out, true
	}

	// out is grown from nil by append, so it never aliases a recordSet slice.
	var out []net.IP
	matched := false
	for pattern, rs := range s.wildcards {
		if !matchWildcard(pattern, domain) {
			continue
		}
		matched = true
		// Overlapping patterns may publish the same address; keep one copy.
		for _, ip := range addrs(rs) {
			if !containsIP(out, ip) {
				out = append(out, ip)
			}
		}
	}
	if len(out) == 0 {
		return nil, matched
	}
	return out, true
}
// GetPTRRecord resolves a reverse-lookup name such as
// "1.0.0.127.in-addr.arpa." to the stored domain. It returns ("", false)
// when the name does not encode an address or no PTR record exists for it.
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
	addr := reverseDNSToIP(domain)
	if addr == nil {
		return "", false
	}

	s.mu.RLock()
	defer s.mu.RUnlock()
	target, found := s.ptrRecords[addr.String()]
	return target, found
}
// HasRecord reports whether domain has at least one address of recordType.
//
// domain is lowercased and made fully qualified before lookup. An exact entry
// takes precedence, mirroring GetRecords: when one exists, only its own
// records count and wildcard patterns are ignored (a wildcard never
// synthesizes records for a name that exists, RFC 4592). Otherwise any
// matching wildcard pattern with records of the type suffices. Types other
// than RecordTypeA and RecordTypeAAAA always report false.
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
	domain = strings.ToLower(dns.Fqdn(domain))

	has := func(rs *recordSet) bool {
		switch recordType {
		case RecordTypeA:
			return len(rs.A) > 0
		case RecordTypeAAAA:
			return len(rs.AAAA) > 0
		default:
			return false
		}
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	if rs, ok := s.exact[domain]; ok {
		return has(rs)
	}
	for pattern, rs := range s.wildcards {
		// Cheap record check first; the pattern match is the expensive part.
		if has(rs) && matchWildcard(pattern, domain) {
			return true
		}
	}
	return false
}
// HasPTRRecord reports whether a PTR record is stored for the address encoded
// by the reverse-lookup name domain. Names that do not encode an address
// report false.
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
	addr := reverseDNSToIP(domain)
	if addr == nil {
		return false
	}

	s.mu.RLock()
	defer s.mu.RUnlock()
	_, found := s.ptrRecords[addr.String()]
	return found
}
// Clear empties the store by swapping in fresh exact, wildcard and PTR maps.
// The store remains usable afterwards.
func (s *DNSRecordStore) Clear() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.exact, s.wildcards = map[string]*recordSet{}, map[string]*recordSet{}
	s.ptrRecords = map[string]string{}
}
// removeIP returns a newly allocated slice holding every element of ips that
// is not net.IP.Equal to toRemove. The input slice is left untouched, and the
// result is never nil, even when empty.
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
	kept := make([]net.IP, 0, len(ips))
	for i := range ips {
		if ips[i].Equal(toRemove) {
			continue
		}
		kept = append(kept, ips[i])
	}
	return kept
}
// matchWildcard reports whether domain matches pattern. In the pattern, '*'
// stands for any run of characters (possibly empty, dots included) and '?'
// for exactly one character. A leading "*." requires at least one character
// before the dot, so "*.example.com." matches "a.example.com." but neither
// "example.com." nor ".example.com.".
func matchWildcard(pattern, domain string) bool {
	const patternStart, domainStart = 0, 0
	return matchWildcardInternal(pattern, domain, patternStart, domainStart)
}
// matchWildcardInternal reports whether pattern[pi:] matches domain[di:].
//
// '*' matches zero or more bytes (including '.'), '?' matches exactly one
// byte, and every other pattern byte must equal the domain byte. When pi == 0
// and the pattern begins with "*.", that leading '*' must match at least one
// byte, which keeps "*.example.com." from matching "example.com." or
// ".example.com.". Out-of-range offsets report false.
//
// The matcher is iterative. On a mismatch it backtracks only to the most
// recent '*', which is sufficient for '*'/'?' globs and bounds the work at
// O(len(pattern) * len(domain)). A naive recursive matcher is exponential in
// the number of '*'s for non-matching inputs, and domain comes from incoming
// queries.
func matchWildcardInternal(pattern, domain string, pi, di int) bool {
	if pi < 0 || di < 0 || pi > len(pattern) || di > len(domain) {
		return false
	}

	// Leading "*." special case: force the '*' to consume one byte up front;
	// after that it behaves like any other '*'.
	if pi == 0 && len(pattern) >= 2 && pattern[0] == '*' && pattern[1] == '.' {
		if di == len(domain) {
			return false
		}
		di++
	}

	// starPi is the index of the last '*' seen (-1 if none); starDi is the
	// domain offset where the bytes after that '*' are currently being tried.
	starPi, starDi := -1, 0
	for di < len(domain) {
		if pi < len(pattern) {
			switch c := pattern[pi]; {
			case c == '*':
				starPi, starDi = pi, di
				pi++
				continue
			case c == '?' || c == domain[di]:
				pi++
				di++
				continue
			}
		}
		if starPi < 0 {
			return false
		}
		// Let the last '*' absorb one more byte and retry after it.
		starDi++
		pi, di = starPi+1, starDi
	}

	// Domain fully consumed: whatever remains of the pattern must be '*'s.
	for pi < len(pattern) && pattern[pi] == '*' {
		pi++
	}
	return pi == len(pattern)
}
// reverseDNSToIP converts a reverse-lookup query name into the IP address it
// encodes. It returns nil if the name is not a complete, well-formed reverse
// name.
//
// The name is matched case-insensitively and may omit the trailing dot.
//   - IPv4: exactly four labels under in-addr.arpa., least-significant octet
//     first, e.g. "1.0.0.127.in-addr.arpa." -> 127.0.0.1. Octets are
//     range-checked by net.ParseIP.
//   - IPv6: exactly 32 labels under ip6.arpa., each a single hex digit,
//     least-significant nibble first (RFC 3596 §2.5).
//
// Partial zones such as "0.168.192.in-addr.arpa." do not encode an address
// and yield nil.
func reverseDNSToIP(domain string) net.IP {
	const v4Suffix, v6Suffix = ".in-addr.arpa.", ".ip6.arpa."

	name := strings.ToLower(domain)
	if !strings.HasSuffix(name, ".") {
		name += "."
	}

	switch {
	case strings.HasSuffix(name, v4Suffix):
		labels := strings.Split(strings.TrimSuffix(name, v4Suffix), ".")
		if len(labels) != net.IPv4len {
			return nil
		}
		return net.ParseIP(labels[3] + "." + labels[2] + "." + labels[1] + "." + labels[0])

	case strings.HasSuffix(name, v6Suffix):
		labels := strings.Split(strings.TrimSuffix(name, v6Suffix), ".")
		if len(labels) != 2*net.IPv6len {
			return nil
		}
		ip := make(net.IP, net.IPv6len)
		for i, label := range labels {
			// Each label must be exactly one hex digit; otherwise malformed
			// names (e.g. "10..") could still reassemble into a valid address.
			if len(label) != 1 {
				return nil
			}
			var nibble byte
			switch c := label[0]; {
			case '0' <= c && c <= '9':
				nibble = c - '0'
			case 'a' <= c && c <= 'f':
				nibble = c - 'a' + 10
			default:
				return nil
			}
			// labels[0] is the low nibble of the last byte, labels[1] its
			// high nibble, and so on towards byte 0.
			b := net.IPv6len - 1 - i/2
			if i%2 == 0 {
				ip[b] |= nibble
			} else {
				ip[b] |= nibble << 4
			}
		}
		return ip
	}
	return nil
}
// IPToReverseDNS returns the fully qualified reverse-lookup name for ip.
// IPv4 addresses (including IPv4-mapped IPv6) become "d.c.b.a.in-addr.arpa.".
// Other 16-byte addresses become 32 single-hex-digit labels,
// least-significant nibble first, under "ip6.arpa.". It returns "" when ip is
// not a valid address.
func IPToReverseDNS(ip net.IP) string {
	if v4 := ip.To4(); v4 != nil {
		return fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa.", v4[3], v4[2], v4[1], v4[0])
	}

	v6 := ip.To16()
	if v6 == nil {
		return ""
	}

	const hexDigits = "0123456789abcdef"
	var b strings.Builder
	b.Grow(4*net.IPv6len + len("ip6.arpa."))
	for i := len(v6) - 1; i >= 0; i-- {
		b.WriteByte(hexDigits[v6[i]&0x0f])
		b.WriteByte('.')
		b.WriteByte(hexDigits[v6[i]>>4])
		b.WriteByte('.')
	}
	b.WriteString("ip6.arpa.")
	return b.String()
}