From 78dc6508a4ef03d814d6c918f17dad3478887d14 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 16 Dec 2025 21:33:41 -0500 Subject: [PATCH] Support wildcard alias records Former-commit-id: cec79bf0147e2f824d38a20306e63b58d8479a1c --- dns/dns_records.go | 203 ++++++++++++++++++++--- dns/dns_records_test.go | 350 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 531 insertions(+), 22 deletions(-) create mode 100644 dns/dns_records_test.go diff --git a/dns/dns_records.go b/dns/dns_records.go index 8d57d68..ed57b77 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -2,6 +2,7 @@ package dns import ( "net" + "strings" "sync" "github.com/miekg/dns" @@ -17,21 +18,26 @@ const ( // DNSRecordStore manages local DNS records for A and AAAA queries type DNSRecordStore struct { - mu sync.RWMutex - aRecords map[string][]net.IP // domain -> list of IPv4 addresses - aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses + mu sync.RWMutex + aRecords map[string][]net.IP // domain -> list of IPv4 addresses + aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses + aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses + aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses } // NewDNSRecordStore creates a new DNS record store func NewDNSRecordStore() *DNSRecordStore { return &DNSRecordStore{ - aRecords: make(map[string][]net.IP), - aaaaRecords: make(map[string][]net.IP), + aRecords: make(map[string][]net.IP), + aaaaRecords: make(map[string][]net.IP), + aWildcards: make(map[string][]net.IP), + aaaaWildcards: make(map[string][]net.IP), } } // AddRecord adds a DNS record mapping (A or AAAA) // domain should be in FQDN format (e.g., "example.com.") +// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // ip should be a valid IPv4 or IPv6 address func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { s.mu.Lock() @@ -45,12 +51,23 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { // Normalize domain to lowercase domain = dns.Fqdn(domain) + // Check if domain contains wildcards + isWildcard := strings.ContainsAny(domain, "*?") + if ip.To4() != nil { // IPv4 address - s.aRecords[domain] = append(s.aRecords[domain], ip) + if isWildcard { + s.aWildcards[domain] = append(s.aWildcards[domain], ip) + } else { + s.aRecords[domain] = append(s.aRecords[domain], ip) + } } else if ip.To16() != nil { // IPv6 address - s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) + if isWildcard { + s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip) + } else { + s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip) + } } else { return &net.ParseError{Type: "IP address", Text: ip.String()} } @@ -59,7 +76,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { } // RemoveRecord removes a specific DNS record mapping -// If ip is nil, removes all records for the domain +// If ip is nil, removes all records for the domain (including wildcards) func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { s.mu.Lock() defer s.mu.Unlock() @@ -72,33 +89,60 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) { // Normalize domain to lowercase domain = dns.Fqdn(domain) + // Check if domain contains wildcards + isWildcard := strings.ContainsAny(domain, "*?") + if ip == nil { // Remove all records for this domain - delete(s.aRecords, domain) - delete(s.aaaaRecords, domain) + if isWildcard { + delete(s.aWildcards, domain) + delete(s.aaaaWildcards, domain) + } else { + delete(s.aRecords, domain) + delete(s.aaaaRecords, domain) + } return } if ip.To4() != nil { // Remove specific IPv4 address - if ips, ok := s.aRecords[domain]; ok { - s.aRecords[domain] = removeIP(ips, ip) - if len(s.aRecords[domain]) == 0 { - delete(s.aRecords, domain) + if isWildcard { + if ips, ok := s.aWildcards[domain]; ok { + s.aWildcards[domain] = removeIP(ips, ip) + if len(s.aWildcards[domain]) == 0 { + delete(s.aWildcards, domain) + } + } + } else { + if ips, ok := s.aRecords[domain]; ok { + s.aRecords[domain] = removeIP(ips, ip) + if len(s.aRecords[domain]) == 0 { + delete(s.aRecords, domain) + } } } } else if ip.To16() != nil { // Remove specific IPv6 address - if ips, ok := s.aaaaRecords[domain]; ok { - s.aaaaRecords[domain] = removeIP(ips, ip) - if len(s.aaaaRecords[domain]) == 0 { - delete(s.aaaaRecords, domain) + if isWildcard { + if ips, ok := s.aaaaWildcards[domain]; ok { + s.aaaaWildcards[domain] = removeIP(ips, ip) + if len(s.aaaaWildcards[domain]) == 0 { + delete(s.aaaaWildcards, domain) + } + } + } else { + if ips, ok := s.aaaaRecords[domain]; ok { + s.aaaaRecords[domain] = removeIP(ips, ip) + if len(s.aaaaRecords[domain]) == 0 { + delete(s.aaaaRecords, domain) + } } } } } // GetRecords returns all IP addresses for a domain and record type +// First checks for exact matches, then checks wildcard patterns func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP { s.mu.RLock() defer s.mu.RUnlock() @@ -109,16 +153,45 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net. var records []net.IP switch recordType { case RecordTypeA: + // Check exact match first if ips, ok := s.aRecords[domain]; ok { // Return a copy to prevent external modifications records = make([]net.IP, len(ips)) copy(records, ips) + return records } + // Check wildcard patterns + for pattern, ips := range s.aWildcards { + if matchWildcard(pattern, domain) { + records = append(records, ips...) + } + } + if len(records) > 0 { + // Return a copy + result := make([]net.IP, len(records)) + copy(result, records) + return result + } + case RecordTypeAAAA: + // Check exact match first if ips, ok := s.aaaaRecords[domain]; ok { // Return a copy to prevent external modifications records = make([]net.IP, len(ips)) copy(records, ips) + return records + } + // Check wildcard patterns + for pattern, ips := range s.aaaaWildcards { + if matchWildcard(pattern, domain) { + records = append(records, ips...) + } + } + if len(records) > 0 { + // Return a copy + result := make([]net.IP, len(records)) + copy(result, records) + return result } } @@ -126,6 +199,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net. } // HasRecord checks if a domain has any records of the specified type +// Checks both exact matches and wildcard patterns func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { s.mu.RLock() defer s.mu.RUnlock() @@ -135,11 +209,27 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool { switch recordType { case RecordTypeA: - _, ok := s.aRecords[domain] - return ok + // Check exact match + if _, ok := s.aRecords[domain]; ok { + return true + } + // Check wildcard patterns + for pattern := range s.aWildcards { + if matchWildcard(pattern, domain) { + return true + } + } case RecordTypeAAAA: - _, ok := s.aaaaRecords[domain] - return ok + // Check exact match + if _, ok := s.aaaaRecords[domain]; ok { + return true + } + // Check wildcard patterns + for pattern := range s.aaaaWildcards { + if matchWildcard(pattern, domain) { + return true + } + } } return false @@ -152,6 +242,8 @@ func (s *DNSRecordStore) Clear() { s.aRecords = make(map[string][]net.IP) s.aaaaRecords = make(map[string][]net.IP) + s.aWildcards = make(map[string][]net.IP) + s.aaaaWildcards = make(map[string][]net.IP) } // removeIP is a helper function to remove a specific IP from a slice @@ -164,3 +256,70 @@ func removeIP(ips []net.IP, toRemove net.IP) []net.IP { } return result } + +// matchWildcard checks if a domain matches a wildcard pattern +// Pattern supports * (0+ chars) and ? (exactly 1 char) +// Special case: *.domain.com does not match domain.com itself +func matchWildcard(pattern, domain string) bool { + return matchWildcardInternal(pattern, domain, 0, 0) +} + +// matchWildcardInternal performs the actual wildcard matching recursively +func matchWildcardInternal(pattern, domain string, pi, di int) bool { + plen := len(pattern) + dlen := len(domain) + + // Base cases + if pi == plen && di == dlen { + return true + } + if pi == plen { + return false + } + + // Handle wildcard characters + if pattern[pi] == '*' { + // Special case: if pattern starts with "*." and we're at the beginning, + // ensure we don't match the domain without a prefix + // e.g., *.autoco.internal should not match autoco.internal + if pi == 0 && pi+1 < plen && pattern[pi+1] == '.' { + // The * must match at least one character + if di == dlen { + return false + } + // Try matching 1 or more characters before the dot + for i := di + 1; i <= dlen; i++ { + if matchWildcardInternal(pattern, domain, pi+1, i) { + return true + } + } + return false + } + + // Normal * matching (0 or more characters) + // Try matching 0 characters (skip the *) + if matchWildcardInternal(pattern, domain, pi+1, di) { + return true + } + // Try matching 1+ characters + if di < dlen { + return matchWildcardInternal(pattern, domain, pi, di+1) + } + return false + } + + if pattern[pi] == '?' { + // ? matches exactly one character + if di >= dlen { + return false + } + return matchWildcardInternal(pattern, domain, pi+1, di+1) + } + + // Regular character - must match exactly + if di >= dlen || pattern[pi] != domain[di] { + return false + } + + return matchWildcardInternal(pattern, domain, pi+1, di+1) +} \ No newline at end of file diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go new file mode 100644 index 0000000..0bb18a1 --- /dev/null +++ b/dns/dns_records_test.go @@ -0,0 +1,350 @@ +package dns + +import ( + "net" + "testing" +) + +func TestWildcardMatching(t *testing.T) { + tests := []struct { + name string + pattern string + domain string + expected bool + }{ + // Basic wildcard tests + { + name: "*.autoco.internal matches host.autoco.internal", + pattern: "*.autoco.internal.", + domain: "host.autoco.internal.", + expected: true, + }, + { + name: "*.autoco.internal matches longerhost.autoco.internal", + pattern: "*.autoco.internal.", + domain: "longerhost.autoco.internal.", + expected: true, + }, + { + name: "*.autoco.internal matches sub.host.autoco.internal", + pattern: "*.autoco.internal.", + domain: "sub.host.autoco.internal.", + expected: true, + }, + { + name: "*.autoco.internal does NOT match autoco.internal", + pattern: "*.autoco.internal.", + domain: "autoco.internal.", + expected: false, + }, + + // Question mark wildcard tests + { + name: "host-0?.autoco.internal matches host-01.autoco.internal", + pattern: "host-0?.autoco.internal.", + domain: "host-01.autoco.internal.", + expected: true, + }, + { + name: "host-0?.autoco.internal matches host-0a.autoco.internal", + pattern: "host-0?.autoco.internal.", + domain: "host-0a.autoco.internal.", + expected: true, + }, + { + name: "host-0?.autoco.internal does NOT match host-0.autoco.internal", + pattern: "host-0?.autoco.internal.", + domain: "host-0.autoco.internal.", + expected: false, + }, + { + name: "host-0?.autoco.internal does NOT match host-012.autoco.internal", + pattern: "host-0?.autoco.internal.", + domain: "host-012.autoco.internal.", + expected: false, + }, + + // Combined wildcard tests + { + name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal", + pattern: "*.host-0?.autoco.internal.", + domain: "sub.host-01.autoco.internal.", + expected: true, + }, + { + name: "*.host-0?.autoco.internal matches prefix.host-0a.autoco.internal", + pattern: "*.host-0?.autoco.internal.", + domain: "prefix.host-0a.autoco.internal.", + expected: true, + }, + { + name: "*.host-0?.autoco.internal does NOT match host-01.autoco.internal", + pattern: "*.host-0?.autoco.internal.", + domain: "host-01.autoco.internal.", + expected: false, + }, + + // Multiple asterisks + { + name: "*.*. autoco.internal matches any.thing.autoco.internal", + pattern: "*.*.autoco.internal.", + domain: "any.thing.autoco.internal.", + expected: true, + }, + { + name: "*.*.autoco.internal does NOT match single.autoco.internal", + pattern: "*.*.autoco.internal.", + domain: "single.autoco.internal.", + expected: false, + }, + + // Asterisk in middle + { + name: "host-*.autoco.internal matches host-anything.autoco.internal", + pattern: "host-*.autoco.internal.", + domain: "host-anything.autoco.internal.", + expected: true, + }, + { + name: "host-*.autoco.internal matches host-.autoco.internal (empty match)", + pattern: "host-*.autoco.internal.", + domain: "host-.autoco.internal.", + expected: true, + }, + + // Multiple question marks + { + name: "host-??.autoco.internal matches host-01.autoco.internal", + pattern: "host-??.autoco.internal.", + domain: "host-01.autoco.internal.", + expected: true, + }, + { + name: "host-??.autoco.internal does NOT match host-1.autoco.internal", + pattern: "host-??.autoco.internal.", + domain: "host-1.autoco.internal.", + expected: false, + }, + + // Exact match (no wildcards) + { + name: "exact.autoco.internal matches exact.autoco.internal", + pattern: "exact.autoco.internal.", + domain: "exact.autoco.internal.", + expected: true, + }, + { + name: "exact.autoco.internal does NOT match other.autoco.internal", + pattern: "exact.autoco.internal.", + domain: "other.autoco.internal.", + expected: false, + }, + + // Edge cases + { + name: "* matches anything", + pattern: "*", + domain: "anything.at.all.", + expected: true, + }, + { + name: "*.* matches multi.level.", + pattern: "*.*", + domain: "multi.level.", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := matchWildcard(tt.pattern, tt.domain) + if result != tt.expected { + t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.domain, result, tt.expected) + } + }) + } +} + +func TestDNSRecordStoreWildcard(t *testing.T) { + store := NewDNSRecordStore() + + // Add wildcard records + wildcardIP := net.ParseIP("10.0.0.1") + err := store.AddRecord("*.autoco.internal", wildcardIP) + if err != nil { + t.Fatalf("Failed to add wildcard record: %v", err) + } + + // Add exact record + exactIP := net.ParseIP("10.0.0.2") + err = store.AddRecord("exact.autoco.internal", exactIP) + if err != nil { + t.Fatalf("Failed to add exact record: %v", err) + } + + // Test exact match takes precedence + ips := store.GetRecords("exact.autoco.internal.", RecordTypeA) + if len(ips) != 1 { + t.Errorf("Expected 1 IP for exact match, got %d", len(ips)) + } + if !ips[0].Equal(exactIP) { + t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0]) + } + + // Test wildcard match + ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + if len(ips) != 1 { + t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips)) + } + if !ips[0].Equal(wildcardIP) { + t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0]) + } + + // Test non-match (base domain) + ips = store.GetRecords("autoco.internal.", RecordTypeA) + if len(ips) != 0 { + t.Errorf("Expected 0 IPs for base domain, got %d", len(ips)) + } +} + +func TestDNSRecordStoreComplexWildcard(t *testing.T) { + store := NewDNSRecordStore() + + // Add complex wildcard pattern + ip1 := net.ParseIP("10.0.0.1") + err := store.AddRecord("*.host-0?.autoco.internal", ip1) + if err != nil { + t.Fatalf("Failed to add wildcard record: %v", err) + } + + // Test matching domain + ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA) + if len(ips) != 1 { + t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips)) + } + if len(ips) > 0 && !ips[0].Equal(ip1) { + t.Errorf("Expected IP %v, got %v", ip1, ips[0]) + } + + // Test non-matching domain (missing prefix) + ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA) + if len(ips) != 0 { + t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips)) + } + + // Test non-matching domain (wrong ? position) + ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA) + if len(ips) != 0 { + t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips)) + } +} + +func TestDNSRecordStoreRemoveWildcard(t *testing.T) { + store := NewDNSRecordStore() + + // Add wildcard record + ip := net.ParseIP("10.0.0.1") + err := store.AddRecord("*.autoco.internal", ip) + if err != nil { + t.Fatalf("Failed to add wildcard record: %v", err) + } + + // Verify it exists + ips := store.GetRecords("host.autoco.internal.", RecordTypeA) + if len(ips) != 1 { + t.Errorf("Expected 1 IP before removal, got %d", len(ips)) + } + + // Remove wildcard record + store.RemoveRecord("*.autoco.internal", nil) + + // Verify it's gone + ips = store.GetRecords("host.autoco.internal.", RecordTypeA) + if len(ips) != 0 { + t.Errorf("Expected 0 IPs after removal, got %d", len(ips)) + } +} + +func TestDNSRecordStoreMultipleWildcards(t *testing.T) { + store := NewDNSRecordStore() + + // Add multiple wildcard patterns that don't overlap + ip1 := net.ParseIP("10.0.0.1") + ip2 := net.ParseIP("10.0.0.2") + ip3 := net.ParseIP("10.0.0.3") + + err := store.AddRecord("*.prod.autoco.internal", ip1) + if err != nil { + t.Fatalf("Failed to add first wildcard: %v", err) + } + + err = store.AddRecord("*.dev.autoco.internal", ip2) + if err != nil { + t.Fatalf("Failed to add second wildcard: %v", err) + } + + // Add a broader wildcard that matches both + err = store.AddRecord("*.autoco.internal", ip3) + if err != nil { + t.Fatalf("Failed to add third wildcard: %v", err) + } + + // Test domain matching only the prod pattern and the broad pattern + ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA) + if len(ips) != 2 { + t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips)) + } + + // Test domain matching only the dev pattern and the broad pattern + ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA) + if len(ips) != 2 { + t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips)) + } + + // Test domain matching only the broad pattern + ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA) + if len(ips) != 1 { + t.Errorf("Expected 1 IP (broad only), got %d", len(ips)) + } +} + +func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { + store := NewDNSRecordStore() + + // Add IPv6 wildcard record + ip := net.ParseIP("2001:db8::1") + err := store.AddRecord("*.autoco.internal", ip) + if err != nil { + t.Fatalf("Failed to add IPv6 wildcard record: %v", err) + } + + // Test wildcard match for IPv6 + ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA) + if len(ips) != 1 { + t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips)) + } + if len(ips) > 0 && !ips[0].Equal(ip) { + t.Errorf("Expected IPv6 %v, got %v", ip, ips[0]) + } +} + +func TestHasRecordWildcard(t *testing.T) { + store := NewDNSRecordStore() + + // Add wildcard record + ip := net.ParseIP("10.0.0.1") + err := store.AddRecord("*.autoco.internal", ip) + if err != nil { + t.Fatalf("Failed to add wildcard record: %v", err) + } + + // Test HasRecord with wildcard match + if !store.HasRecord("host.autoco.internal.", RecordTypeA) { + t.Error("Expected HasRecord to return true for wildcard match") + } + + // Test HasRecord with non-match + if store.HasRecord("autoco.internal.", RecordTypeA) { + t.Error("Expected HasRecord to return false for base domain") + } +} \ No newline at end of file