Lowercase all domains before matching

Former-commit-id: 8f8872aa47
This commit is contained in:
Owen
2026-01-30 14:53:25 -08:00
parent 1869e70894
commit 7fc3c7088e
2 changed files with 71 additions and 6 deletions

View File

@@ -48,8 +48,8 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards // Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
@@ -86,8 +86,8 @@ func (s *DNSRecordStore) RemoveRecord(domain string, ip net.IP) {
domain = domain + "." domain = domain + "."
} }
// Normalize domain to lowercase // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
// Check if domain contains wildcards // Check if domain contains wildcards
isWildcard := strings.ContainsAny(domain, "*?") isWildcard := strings.ContainsAny(domain, "*?")
@@ -148,7 +148,7 @@ func (s *DNSRecordStore) GetRecords(domain string, recordType RecordType) []net.
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
var records []net.IP var records []net.IP
switch recordType { switch recordType {
@@ -205,7 +205,7 @@ func (s *DNSRecordStore) HasRecord(domain string, recordType RecordType) bool {
defer s.mu.RUnlock() defer s.mu.RUnlock()
// Normalize domain to lowercase FQDN // Normalize domain to lowercase FQDN
domain = dns.Fqdn(domain) domain = strings.ToLower(dns.Fqdn(domain))
switch recordType { switch recordType {
case RecordTypeA: case RecordTypeA:

View File

@@ -348,3 +348,68 @@ func TestHasRecordWildcard(t *testing.T) {
t.Error("Expected HasRecord to return false for base domain") t.Error("Expected HasRecord to return false for base domain")
} }
} }
func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
store := NewDNSRecordStore()
// Add record with mixed case
ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip)
if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err)
}
// Test lookup with different cases
testCases := []string{
"myhost.autoco.internal.",
"MYHOST.AUTOCO.INTERNAL.",
"MyHost.AutoCo.Internal.",
"mYhOsT.aUtOcO.iNtErNaL.",
}
for _, domain := range testCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(ip) {
t.Errorf("Expected IP %v for domain %q, got %v", ip, domain, ips[0])
}
}
// Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP)
if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err)
}
wildcardTestCases := []string{
"host.example.com.",
"HOST.EXAMPLE.COM.",
"Host.Example.Com.",
"HoSt.ExAmPlE.CoM.",
}
for _, domain := range wildcardTestCases {
ips := store.GetRecords(domain, RecordTypeA)
if len(ips) != 1 {
t.Errorf("Expected 1 IP for wildcard domain %q, got %d", domain, len(ips))
}
if len(ips) > 0 && !ips[0].Equal(wildcardIP) {
t.Errorf("Expected IP %v for wildcard domain %q, got %v", wildcardIP, domain, ips[0])
}
}
// Test removal with different case
store.RemoveRecord("MYHOST.AUTOCO.INTERNAL", nil)
ips := store.GetRecords("myhost.autoco.internal.", RecordTypeA)
if len(ips) != 0 {
t.Errorf("Expected 0 IPs after removal, got %d", len(ips))
}
// Test HasRecord with different case
if !store.HasRecord("HOST.EXAMPLE.COM.", RecordTypeA) {
t.Error("Expected HasRecord to return true for mixed case wildcard match")
}
}