Support wildcard alias records

Former-commit-id: cec79bf014
This commit is contained in:
Owen
2025-12-16 21:33:41 -05:00
parent 7f6c824122
commit 78dc6508a4
2 changed files with 531 additions and 22 deletions

View File

@@ -2,6 +2,7 @@ package dns
import (
"net"
"strings"
"sync"
"github.com/miekg/dns"
@@ -17,21 +18,26 @@ const (
// DNSRecordStore manages local DNS records for A and AAAA queries
type DNSRecordStore struct {
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
mu sync.RWMutex
aRecords map[string][]net.IP // domain -> list of IPv4 addresses
aaaaRecords map[string][]net.IP // domain -> list of IPv6 addresses
aWildcards map[string][]net.IP // wildcard pattern -> list of IPv4 addresses
aaaaWildcards map[string][]net.IP // wildcard pattern -> list of IPv6 addresses
}
// NewDNSRecordStore creates a new DNS record store
func NewDNSRecordStore() *DNSRecordStore {
return &DNSRecordStore{
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
aRecords: make(map[string][]net.IP),
aaaaRecords: make(map[string][]net.IP),
aWildcards: make(map[string][]net.IP),
aaaaWildcards: make(map[string][]net.IP),
}
}
// AddRecord adds a DNS record mapping (A or AAAA)
// domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
s.mu.Lock()
@@ -45,12 +51,23 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
if ip.To4() != nil {
// IPv4 address
s.aRecords[domain] = append(s.aRecords[domain], ip)
if isWildcard {
s.aWildcards[domain] = append(s.aWildcards[domain], ip)
} else {
s.aRecords[domain] = append(s.aRecords[domain], ip)
}
} else if ip.To16() != nil {
// IPv6 address
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
if isWildcard {
s.aaaaWildcards[domain] = append(s.aaaaWildcards[domain], ip)
} else {
s.aaaaRecords[domain] = append(s.aaaaRecords[domain], ip)
}
} else {
return &net.ParseError{Type: "IP address", Text: ip.String()}
}
@@ -59,7 +76,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
}
// RemoveRecord removes a specific DNS record mapping
// If ip is nil, removes all records for the domain
// If ip is nil, removes all records for the domain (including wildcards)
func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -72,33 +89,60 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
// Normalize domain to lowercase
domain = dns.Fqdn(domain)
// Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?")
if ip == nil {
// Remove all records for this domain
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
if isWildcard {
delete(s.aWildcards, domain)
delete(s.aaaaWildcards, domain)
} else {
delete(s.aRecords, domain)
delete(s.aaaaRecords, domain)
}
return
}
if ip.To4() != nil {
// Remove specific IPv4 address
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
if isWildcard {
if ips, ok := s.aWildcards[domain]; ok {
s.aWildcards[domain] = removeIP(ips, ip)
if len(s.aWildcards[domain]) == 0 {
delete(s.aWildcards, domain)
}
}
} else {
if ips, ok := s.aRecords[domain]; ok {
s.aRecords[domain] = removeIP(ips, ip)
if len(s.aRecords[domain]) == 0 {
delete(s.aRecords, domain)
}
}
}
} else if ip.To16() != nil {
// Remove specific IPv6 address
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
if isWildcard {
if ips, ok := s.aaaaWildcards[domain]; ok {
s.aaaaWildcards[domain] = removeIP(ips, ip)
if len(s.aaaaWildcards[domain]) == 0 {
delete(s.aaaaWildcards, domain)
}
}
} else {
if ips, ok := s.aaaaRecords[domain]; ok {
s.aaaaRecords[domain] = removeIP(ips, ip)
if len(s.aaaaRecords[domain]) == 0 {
delete(s.aaaaRecords, domain)
}
}
}
}
}
// GetRecords returns all IP addresses for a domain and record type
// First checks for exact matches, then checks wildcard patterns
func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.IP {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -109,16 +153,45 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
var records []net.IP
switch recordType {
case RecordTypeA:
// Check exact match first
if ips, ok := s.aRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
case RecordTypeAAAA:
// Check exact match first
if ips, ok := s.aaaaRecords[domain]; ok {
// Return a copy to prevent external modifications
records = make([]net.IP, len(ips))
copy(records, ips)
return records
}
// Check wildcard patterns
for pattern, ips := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
records = append(records, ips...)
}
}
if len(records) > 0 {
// Return a copy
result := make([]net.IP, len(records))
copy(result, records)
return result
}
}
@@ -126,6 +199,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
}
// HasRecord checks if a domain has any records of the specified type
// Checks both exact matches and wildcard patterns
func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -135,11 +209,27 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
switch recordType {
case RecordTypeA:
_, ok := s.aRecords[domain]
return ok
// Check exact match
if _, ok := s.aRecords[domain]; ok {
return true
}
// Check wildcard patterns
for pattern := range s.aWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
case RecordTypeAAAA:
_, ok := s.aaaaRecords[domain]
return ok
// Check exact match
if _, ok := s.aaaaRecords[domain]; ok {
return true
}
// Check wildcard patterns
for pattern := range s.aaaaWildcards {
if matchWildcard(pattern, domain) {
return true
}
}
}
return false
@@ -152,6 +242,8 @@ func (s *DNSRecordStore) Clear() {
s.aRecords = make(map[string][]net.IP)
s.aaaaRecords = make(map[string][]net.IP)
s.aWildcards = make(map[string][]net.IP)
s.aaaaWildcards = make(map[string][]net.IP)
}
// removeIP is a helper function to remove a specific IP from a slice
@@ -164,3 +256,70 @@ func removeIP(ips []net.IP, toRemove net.IP) []net.IP {
}
return result
}
// matchWildcard checks if a domain matches a wildcard pattern
// Pattern supports * (0+ chars) and ? (exactly 1 char)
// Special case: *.domain.com does not match domain.com itself
func matchWildcard(pattern, domain string) bool {
return matchWildcardInternal(pattern, domain, 0, 0)
}
// matchWildcardInternal performs the actual wildcard matching recursively
func matchWildcardInternal(pattern, domain string, pi, di int) bool {
plen := len(pattern)
dlen := len(domain)
// Base cases
if pi == plen && di == dlen {
return true
}
if pi == plen {
return false
}
// Handle wildcard characters
if pattern[pi] == '*' {
// Special case: if pattern starts with "*." and we're at the beginning,
// ensure we don't match the domain without a prefix
// e.g., *.autoco.internal should not match autoco.internal
if pi == 0 && pi+1 < plen && pattern[pi+1] == '.' {
// The * must match at least one character
if di == dlen {
return false
}
// Try matching 1 or more characters before the dot
for i := di + 1; i <= dlen; i++ {
if matchWildcardInternal(pattern, domain, pi+1, i) {
return true
}
}
return false
}
// Normal * matching (0 or more characters)
// Try matching 0 characters (skip the *)
if matchWildcardInternal(pattern, domain, pi+1, di) {
return true
}
// Try matching 1+ characters
if di < dlen {
return matchWildcardInternal(pattern, domain, pi, di+1)
}
return false
}
if pattern[pi] == '?' {
// ? matches exactly one character
if di >= dlen {
return false
}
return matchWildcardInternal(pattern, domain, pi+1, di+1)
}
// Regular character - must match exactly
if di >= dlen || pattern[pi] != domain[di] {
return false
}
return matchWildcardInternal(pattern, domain, pi+1, di+1)
}

350
dns/dns_records_test.go Normal file
View File

@@ -0,0 +1,350 @@
package dns
import (
"net"
"testing"
)
func TestWildcardMatching(t *testing.T) {
tests := []struct {
name string
pattern string
domain string
expected bool
}{
// Basic wildcard tests
{
name: "*.autoco.internal matches host.autoco.internal",
pattern: "*.autoco.internal.",
domain: "host.autoco.internal.",
expected: true,
},
{
name: "*.autoco.internal matches longerhost.autoco.internal",
pattern: "*.autoco.internal.",
domain: "longerhost.autoco.internal.",
expected: true,
},
{
name: "*.autoco.internal matches sub.host.autoco.internal",
pattern: "*.autoco.internal.",
domain: "sub.host.autoco.internal.",
expected: true,
},
{
name: "*.autoco.internal does NOT match autoco.internal",
pattern: "*.autoco.internal.",
domain: "autoco.internal.",
expected: false,
},
// Question mark wildcard tests
{
name: "host-0?.autoco.internal matches host-01.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-01.autoco.internal.",
expected: true,
},
{
name: "host-0?.autoco.internal matches host-0a.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-0a.autoco.internal.",
expected: true,
},
{
name: "host-0?.autoco.internal does NOT match host-0.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-0.autoco.internal.",
expected: false,
},
{
name: "host-0?.autoco.internal does NOT match host-012.autoco.internal",
pattern: "host-0?.autoco.internal.",
domain: "host-012.autoco.internal.",
expected: false,
},
// Combined wildcard tests
{
name: "*.host-0?.autoco.internal matches sub.host-01.autoco.internal",
pattern: "*.host-0?.autoco.internal.",
domain: "sub.host-01.autoco.internal.",
expected: true,
},
{
name: "*.host-0?.autoco.internal matches prefix.host-0a.autoco.internal",
pattern: "*.host-0?.autoco.internal.",
domain: "prefix.host-0a.autoco.internal.",
expected: true,
},
{
name: "*.host-0?.autoco.internal does NOT match host-01.autoco.internal",
pattern: "*.host-0?.autoco.internal.",
domain: "host-01.autoco.internal.",
expected: false,
},
// Multiple asterisks
{
name: "*.*. autoco.internal matches any.thing.autoco.internal",
pattern: "*.*.autoco.internal.",
domain: "any.thing.autoco.internal.",
expected: true,
},
{
name: "*.*.autoco.internal does NOT match single.autoco.internal",
pattern: "*.*.autoco.internal.",
domain: "single.autoco.internal.",
expected: false,
},
// Asterisk in middle
{
name: "host-*.autoco.internal matches host-anything.autoco.internal",
pattern: "host-*.autoco.internal.",
domain: "host-anything.autoco.internal.",
expected: true,
},
{
name: "host-*.autoco.internal matches host-.autoco.internal (empty match)",
pattern: "host-*.autoco.internal.",
domain: "host-.autoco.internal.",
expected: true,
},
// Multiple question marks
{
name: "host-??.autoco.internal matches host-01.autoco.internal",
pattern: "host-??.autoco.internal.",
domain: "host-01.autoco.internal.",
expected: true,
},
{
name: "host-??.autoco.internal does NOT match host-1.autoco.internal",
pattern: "host-??.autoco.internal.",
domain: "host-1.autoco.internal.",
expected: false,
},
// Exact match (no wildcards)
{
name: "exact.autoco.internal matches exact.autoco.internal",
pattern: "exact.autoco.internal.",
domain: "exact.autoco.internal.",
expected: true,
},
{
name: "exact.autoco.internal does NOT match other.autoco.internal",
pattern: "exact.autoco.internal.",
domain: "other.autoco.internal.",
expected: false,
},
// Edge cases
{
name: "* matches anything",
pattern: "*",
domain: "anything.at.all.",
expected: true,
},
{
name: "*.* matches multi.level.",
pattern: "*.*",
domain: "multi.level.",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcard(tt.pattern, tt.domain)
if result != tt.expected {
t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.domain, result, tt.expected)
}
})
}
}
func TestDNSRecordStoreWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Add exact record
exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP)
if err != nil {
t.Fatalf("Failed to add exact record: %v", err)
}
// Test exact match takes precedence
ips := store.GetRecords("exact.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for exact match, got %d", len(ips))
}
if !ips[0].Equal(exactIP) {
t.Errorf("Expected exact IP %v, got %v", exactIP, ips[0])
}
// Test wildcard match
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard match, got %d", len(ips))
}
if !ips[0].Equal(wildcardIP) {
t.Errorf("Expected wildcard IP %v, got %v", wildcardIP, ips[0])
}
// Test non-match (base domain)
ips = store.GetRecords("autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for base domain, got %d", len(ips))
}
}
func TestDNSRecordStoreComplexWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test matching domain
ips := store.GetRecords("sub.host-01.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for complex wildcard match, got %d", len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip1) {
t.Errorf("Expected IP %v, got %v", ip1, ips[0])
}
// Test non-matching domain (missing prefix)
ips = store.GetRecords("host-01.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain without prefix, got %d", len(ips))
}
// Test non-matching domain (wrong ? position)
ips = store.GetRecords("sub.host-012.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs for domain with wrong ? match, got %d", len(ips))
}
}
func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Verify it exists
ips := store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP before removal, got %d", len(ips))
}
// Remove wildcard record
store.RemoveRecord("*.autoco.internal", nil)
// Verify it's gone
ips = store.GetRecords("host.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
}
func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
store := NewDNSRecordStore()
// Add multiple wildcard patterns that don't overlap
ip1 := net.ParseIP("10.0.0.1")
ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1)
if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err)
}
err = store.AddRecord("*.dev.autoco.internal", ip2)
if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err)
}
// Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3)
if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err)
}
// Test domain matching only the prod pattern and the broad pattern
ips := store.GetRecords("host.prod.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (prod + broad), got %d", len(ips))
}
// Test domain matching only the dev pattern and the broad pattern
ips = store.GetRecords("service.dev.autoco.internal.", RecordTypeA)
if len(ips) != 2 {
t.Errorf("Expected 2 IPs (dev + broad), got %d", len(ips))
}
// Test domain matching only the broad pattern
ips = store.GetRecords("host.test.autoco.internal.", RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP (broad only), got %d", len(ips))
}
}
func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
}
// Test wildcard match for IPv6
ips := store.GetRecords("host.autoco.internal.", RecordTypeAAAA)
if len(ips) != 1 {
t.Errorf("Expected 1 IPv6 for wildcard match, got %d", len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip) {
t.Errorf("Expected IPv6 %v, got %v", ip, ips[0])
}
}
func TestHasRecordWildcard(t *testing.T) {
store := NewDNSRecordStore()
// Add wildcard record
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip)
if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err)
}
// Test HasRecord with wildcard match
if !store.HasRecord("host.autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return true for wildcard match")
}
// Test HasRecord with non-match
if store.HasRecord("autoco.internal.", RecordTypeA) {
t.Error("Expected HasRecord to return false for base domain")
}
}