mirror of
https://github.com/fosrl/olm.git
synced 2026-03-07 03:06:44 +00:00
Replace trie-based domain lookup with simple map for O(1) lookups. Add exists boolean to GetRecords for proper NODATA vs NXDOMAIN responses.
461 lines
12 KiB
Go
461 lines
12 KiB
Go
package dns
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// RecordType represents the type of DNS record
|
|
type RecordType uint16
|
|
|
|
const (
|
|
RecordTypeA RecordType = RecordType(dns.TypeA)
|
|
RecordTypeAAAA RecordType = RecordType(dns.TypeAAAA)
|
|
RecordTypePTR RecordType = RecordType(dns.TypePTR)
|
|
)
|
|
|
|
// recordSet is the address data published for one exact domain or one
// wildcard pattern, split by address family.
type recordSet struct {
	A    []net.IP // addresses served for A queries
	AAAA []net.IP // addresses served for AAAA queries
}
|
|
|
|
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
|
|
// Exact domains are stored in a map; wildcard patterns are in a separate map.
|
|
type DNSRecordStore struct {
|
|
mu sync.RWMutex
|
|
exact map[string]*recordSet // normalized FQDN -> A/AAAA records
|
|
wildcards map[string]*recordSet // wildcard pattern -> A/AAAA records
|
|
ptrRecords map[string]string // IP address string -> domain name
|
|
}
|
|
|
|
// NewDNSRecordStore creates a new DNS record store
|
|
func NewDNSRecordStore() *DNSRecordStore {
|
|
return &DNSRecordStore{
|
|
exact: make(map[string]*recordSet),
|
|
wildcards: make(map[string]*recordSet),
|
|
ptrRecords: make(map[string]string),
|
|
}
|
|
}
|
|
|
|
// AddRecord adds a DNS record mapping (A or AAAA)
|
|
// domain should be in FQDN format (e.g., "example.com.")
|
|
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
|
|
// ip should be a valid IPv4 or IPv6 address
|
|
// Automatically adds a corresponding PTR record for non-wildcard domains
|
|
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
|
domain = domain + "."
|
|
}
|
|
domain = strings.ToLower(dns.Fqdn(domain))
|
|
isWildcard := strings.ContainsAny(domain, "*?")
|
|
|
|
isV4 := ip.To4() != nil
|
|
if !isV4 && ip.To16() == nil {
|
|
return &net.ParseError{Type: "IP address", Text: ip.String()}
|
|
}
|
|
|
|
// Choose the appropriate map based on whether this is a wildcard
|
|
m := s.exact
|
|
if isWildcard {
|
|
m = s.wildcards
|
|
}
|
|
|
|
if m[domain] == nil {
|
|
m[domain] = &recordSet{}
|
|
}
|
|
rs := m[domain]
|
|
if isV4 {
|
|
rs.A = append(rs.A, ip)
|
|
} else {
|
|
rs.AAAA = append(rs.AAAA, ip)
|
|
}
|
|
|
|
// Add PTR record for non-wildcard domains
|
|
if !isWildcard {
|
|
s.ptrRecords[ip.String()] = domain
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AddPTRRecord adds a PTR record mapping an IP address to a domain name
|
|
// ip should be a valid IPv4 or IPv6 address
|
|
// domain should be in FQDN format (e.g., "example.com.")
|
|
func (s *DNSRecordStore) AddPTRRecord(ip net.IP, domain string) error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
// Ensure domain ends with a dot (FQDN format)
|
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
|
domain = domain + "."
|
|
}
|
|
|
|
// Normalize domain to lowercase FQDN
|
|
domain = strings.ToLower(dns.Fqdn(domain))
|
|
|
|
// Store PTR record using IP string as key
|
|
s.ptrRecords[ip.String()] = domain
|
|
|
|
return nil
|
|
}
|
|
|
|
// RemoveRecord removes a specific DNS record mapping
|
|
// If ip is nil, removes all records for the domain (including wildcards)
|
|
// Automatically removes corresponding PTR records for non-wildcard domains
|
|
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if len(domain) == 0 || domain[len(domain)-1] != '.' {
|
|
domain = domain + "."
|
|
}
|
|
domain = strings.ToLower(dns.Fqdn(domain))
|
|
isWildcard := strings.ContainsAny(domain, "*?")
|
|
|
|
// Choose the appropriate map
|
|
m := s.exact
|
|
if isWildcard {
|
|
m = s.wildcards
|
|
}
|
|
|
|
rs := m[domain]
|
|
if rs == nil {
|
|
return
|
|
}
|
|
|
|
if ip == nil {
|
|
// Remove all records for this domain
|
|
if !isWildcard {
|
|
for _, ipAddr := range rs.A {
|
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
|
delete(s.ptrRecords, ipAddr.String())
|
|
}
|
|
}
|
|
for _, ipAddr := range rs.AAAA {
|
|
if ptrDomain, exists := s.ptrRecords[ipAddr.String()]; exists && ptrDomain == domain {
|
|
delete(s.ptrRecords, ipAddr.String())
|
|
}
|
|
}
|
|
}
|
|
delete(m, domain)
|
|
return
|
|
}
|
|
|
|
// Remove specific IP
|
|
if ip.To4() != nil {
|
|
rs.A = removeIP(rs.A, ip)
|
|
if !isWildcard {
|
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
|
delete(s.ptrRecords, ip.String())
|
|
}
|
|
}
|
|
} else {
|
|
rs.AAAA = removeIP(rs.AAAA, ip)
|
|
if !isWildcard {
|
|
if ptrDomain, exists := s.ptrRecords[ip.String()]; exists && ptrDomain == domain {
|
|
delete(s.ptrRecords, ip.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
// Clean up empty record sets
|
|
if len(rs.A) == 0 && len(rs.AAAA) == 0 {
|
|
delete(m, domain)
|
|
}
|
|
}
|
|
|
|
// RemovePTRRecord removes a PTR record for an IP address
|
|
func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
delete(s.ptrRecords, ip.String())
|
|
}
|
|
|
|
// GetRecords returns all IP addresses for a domain and record type.
|
|
// The second return value indicates whether the domain exists at all
|
|
// (true = domain exists, use NODATA if no records; false = NXDOMAIN).
|
|
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) ([]net.IP, bool) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
domain = strings.ToLower(dns.Fqdn(domain))
|
|
|
|
// Check exact match first
|
|
if rs, exists := s.exact[domain]; exists {
|
|
var ips []net.IP
|
|
if recordType == RecordTypeA {
|
|
ips = rs.A
|
|
} else {
|
|
ips = rs.AAAA
|
|
}
|
|
if len(ips) > 0 {
|
|
out := make([]net.IP, len(ips))
|
|
copy(out, ips)
|
|
return out, true
|
|
}
|
|
// Domain exists but no records of this type
|
|
return nil, true
|
|
}
|
|
|
|
// Check wildcard matches
|
|
var records []net.IP
|
|
matched := false
|
|
for pattern, rs := range s.wildcards {
|
|
if !matchWildcard(pattern, domain) {
|
|
continue
|
|
}
|
|
matched = true
|
|
if recordType == RecordTypeA {
|
|
records = append(records, rs.A...)
|
|
} else {
|
|
records = append(records, rs.AAAA...)
|
|
}
|
|
}
|
|
|
|
if !matched {
|
|
return nil, false
|
|
}
|
|
if len(records) == 0 {
|
|
return nil, true
|
|
}
|
|
out := make([]net.IP, len(records))
|
|
copy(out, records)
|
|
return out, true
|
|
}
|
|
|
|
// GetPTRRecord returns the domain name for a PTR record query
|
|
// domain should be in reverse DNS format (e.g., "1.0.0.127.in-addr.arpa.")
|
|
func (s *DNSRecordStore) GetPTRRecord(domain string) (string, bool) {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
// Convert reverse DNS format to IP address
|
|
ip := reverseDNSToIP(domain)
|
|
if ip == nil {
|
|
return "", false
|
|
}
|
|
|
|
// Look up the PTR record
|
|
if ptrDomain, ok := s.ptrRecords[ip.String()]; ok {
|
|
return ptrDomain, true
|
|
}
|
|
|
|
return "", false
|
|
}
|
|
|
|
// HasRecord checks if a domain has any records of the specified type
|
|
// Checks both exact matches and wildcard patterns
|
|
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
domain = strings.ToLower(dns.Fqdn(domain))
|
|
|
|
// Check exact match
|
|
if rs, exists := s.exact[domain]; exists {
|
|
if recordType == RecordTypeA && len(rs.A) > 0 {
|
|
return true
|
|
}
|
|
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
|
return true
|
|
}
|
|
}
|
|
|
|
// Check wildcard matches
|
|
for pattern, rs := range s.wildcards {
|
|
if !matchWildcard(pattern, domain) {
|
|
continue
|
|
}
|
|
if recordType == RecordTypeA && len(rs.A) > 0 {
|
|
return true
|
|
}
|
|
if recordType == RecordTypeAAAA && len(rs.AAAA) > 0 {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// HasPTRRecord checks if a PTR record exists for the given reverse DNS domain
|
|
func (s *DNSRecordStore) HasPTRRecord(domain string) bool {
|
|
s.mu.RLock()
|
|
defer s.mu.RUnlock()
|
|
|
|
// Convert reverse DNS format to IP address
|
|
ip := reverseDNSToIP(domain)
|
|
if ip == nil {
|
|
return false
|
|
}
|
|
|
|
_, ok := s.ptrRecords[ip.String()]
|
|
return ok
|
|
}
|
|
|
|
// Clear removes all records from the store
|
|
func (s *DNSRecordStore) Clear() {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
s.exact = make(map[string]*recordSet)
|
|
s.wildcards = make(map[string]*recordSet)
|
|
s.ptrRecords = make(map[string]string)
|
|
}
|
|
|
|
// removeIP returns a new slice containing, in their original order, the
// entries of ips that are not equal (per net.IP.Equal) to toRemove. The
// input slice is left unmodified, and the result is never nil.
func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
	kept := make([]net.IP, 0, len(ips))
	for i := range ips {
		if ips[i].Equal(toRemove) {
			continue
		}
		kept = append(kept, ips[i])
	}
	return kept
}
|
|
|
|
// matchWildcard reports whether domain matches the wildcard pattern.
//
// In the pattern, * matches zero or more characters (dots included), ? matches
// exactly one character, and every other byte must match literally.
// Special case: when the pattern begins with "*.", the leading * must match at
// least one character, so "*.domain.com." does not match "domain.com." itself.
func matchWildcard(pattern, domain string) bool {
	return matchWildcardInternal(pattern, domain, 0, 0)
}

// matchWildcardInternal reports whether pattern[pi:] matches domain[di:].
//
// It uses the iterative glob algorithm. On a mismatch it backtracks only to
// the most recent '*' and lets that star absorb one more character, which
// bounds the work to O(len(pattern)*len(domain)). The previous recursive
// version retried every split point for every star, so a pattern with several
// stars matched against a long name could take time exponential in the number
// of stars.
//
// The leading "*." rule described on matchWildcard applies only when pi == 0.
func matchWildcardInternal(pattern, domain string, pi, di int) bool {
	plen, dlen := len(pattern), len(domain)

	if pi == 0 && plen > 1 && pattern[0] == '*' && pattern[1] == '.' {
		// A star that must match one or more characters behaves like "?*":
		// consume one character unconditionally, then treat it as a normal star.
		if di >= dlen {
			return false
		}
		di++
	}

	starP, starD := -1, 0 // last '*' seen in pattern, and where its match began
	for di < dlen {
		switch {
		case pi < plen && pattern[pi] == '*':
			// Tentatively let the star match nothing; remember where to resume.
			starP, starD = pi, di
			pi++
		case pi < plen && (pattern[pi] == '?' || pattern[pi] == domain[di]):
			pi++
			di++
		case starP >= 0:
			// Mismatch: extend the most recent star by one character and retry.
			starD++
			pi, di = starP+1, starD
		default:
			return false
		}
	}

	// The domain is consumed, so only stars may remain in the pattern.
	for pi < plen && pattern[pi] == '*' {
		pi++
	}
	return pi == plen
}
|
|
|
|
// reverseDNSToIP converts a reverse DNS query name to an IP address
|
|
// Supports both IPv4 (in-addr.arpa) and IPv6 (ip6.arpa) formats
|
|
func reverseDNSToIP(domain string) net.IP {
|
|
// Normalize to lowercase and ensure FQDN
|
|
domain = strings.ToLower(dns.Fqdn(domain))
|
|
|
|
// Check for IPv4 reverse DNS (in-addr.arpa)
|
|
if strings.HasSuffix(domain, ".in-addr.arpa.") {
|
|
// Remove the suffix
|
|
ipPart := strings.TrimSuffix(domain, ".in-addr.arpa.")
|
|
// Split by dots and reverse
|
|
parts := strings.Split(ipPart, ".")
|
|
if len(parts) != 4 {
|
|
return nil
|
|
}
|
|
// Reverse the octets
|
|
reversed := make([]string, 4)
|
|
for i := 0; i < 4; i++ {
|
|
reversed[i] = parts[3-i]
|
|
}
|
|
// Parse as IP
|
|
return net.ParseIP(strings.Join(reversed, "."))
|
|
}
|
|
|
|
// Check for IPv6 reverse DNS (ip6.arpa)
|
|
if strings.HasSuffix(domain, ".ip6.arpa.") {
|
|
// Remove the suffix
|
|
ipPart := strings.TrimSuffix(domain, ".ip6.arpa.")
|
|
// Split by dots and reverse
|
|
parts := strings.Split(ipPart, ".")
|
|
if len(parts) != 32 {
|
|
return nil
|
|
}
|
|
// Reverse the nibbles and group into 16-bit hex values
|
|
reversed := make([]string, 32)
|
|
for i := 0; i < 32; i++ {
|
|
reversed[i] = parts[31-i]
|
|
}
|
|
// Join into IPv6 format (groups of 4 nibbles separated by colons)
|
|
var ipv6Parts []string
|
|
for i := 0; i < 32; i += 4 {
|
|
ipv6Parts = append(ipv6Parts, reversed[i]+reversed[i+1]+reversed[i+2]+reversed[i+3])
|
|
}
|
|
// Parse as IP
|
|
return net.ParseIP(strings.Join(ipv6Parts, ":"))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// IPToReverseDNS converts an IP address to reverse DNS format
|
|
// Returns the domain name for PTR queries (e.g., "1.0.0.127.in-addr.arpa.")
|
|
func IPToReverseDNS(ip net.IP) string {
|
|
if ip4 := ip.To4(); ip4 != nil {
|
|
// IPv4: reverse octets and append .in-addr.arpa.
|
|
return dns.Fqdn(fmt.Sprintf("%d.%d.%d.%d.in-addr.arpa",
|
|
ip4[3], ip4[2], ip4[1], ip4[0]))
|
|
}
|
|
|
|
if ip6 := ip.To16(); ip6 != nil && ip.To4() == nil {
|
|
// IPv6: expand to 32 nibbles, reverse, and append .ip6.arpa.
|
|
var nibbles []string
|
|
for i := 15; i >= 0; i-- {
|
|
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]&0x0f))
|
|
nibbles = append(nibbles, fmt.Sprintf("%x", ip6[i]>>4))
|
|
}
|
|
return dns.Fqdn(strings.Join(nibbles, ".") + ".ip6.arpa")
|
|
}
|
|
|
|
return ""
|
|
}
|