diff --git a/dns/dns_proxy.go b/dns/dns_proxy.go index 27770e4..7b7858c 100644 --- a/dns/dns_proxy.go +++ b/dns/dns_proxy.go @@ -721,8 +721,8 @@ func (p *DNSProxy) runPacketSender() { // AddDNSRecord adds a DNS record to the local store // domain should be a domain name (e.g., "example.com" or "example.com.") // ip should be a valid IPv4 or IPv6 address -func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { - return p.recordStore.AddRecord(domain, ip) +func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error { + return p.recordStore.AddRecord(domain, ip, siteId) } // RemoveDNSRecord removes a DNS record from the local store diff --git a/dns/dns_proxy_test.go b/dns/dns_proxy_test.go index 4a1d9f9..9eecad7 100644 --- a/dns/dns_proxy_test.go +++ b/dns/dns_proxy_test.go @@ -14,7 +14,7 @@ func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) { // Add an A record for a domain (no AAAA record) ip := net.ParseIP("10.0.0.1") - err := proxy.recordStore.AddRecord("myservice.internal", ip) + err := proxy.recordStore.AddRecord("myservice.internal", ip, 0) if err != nil { t.Fatalf("Failed to add A record: %v", err) } @@ -64,7 +64,7 @@ func TestCheckLocalRecordsNODATAForA(t *testing.T) { // Add an AAAA record for a domain (no A record) ip := net.ParseIP("2001:db8::1") - err := proxy.recordStore.AddRecord("ipv6only.internal", ip) + err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0) if err != nil { t.Fatalf("Failed to add AAAA record: %v", err) } @@ -113,7 +113,7 @@ func TestCheckLocalRecordsNonExistentDomain(t *testing.T) { } // Add a record so the store isn't empty - err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) + err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0) if err != nil { t.Fatalf("Failed to add record: %v", err) } @@ -144,7 +144,7 @@ func TestCheckLocalRecordsNODATAWildcard(t *testing.T) { // Add a wildcard A record (no AAAA) ip := net.ParseIP("10.0.0.1") - err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) + err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard A record: %v", err) } diff --git a/dns/dns_records.go b/dns/dns_records.go index 10bb7f3..c52c08e 100644 --- a/dns/dns_records.go +++ b/dns/dns_records.go @@ -20,8 +20,9 @@ const ( // recordSet holds A and AAAA records for a single domain or wildcard pattern type recordSet struct { - A []net.IP - AAAA []net.IP + A []net.IP + AAAA []net.IP + SiteId int } // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. @@ -46,8 +47,9 @@ func NewDNSRecordStore() *DNSRecordStore { // domain should be in FQDN format (e.g., "example.com.") // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // ip should be a valid IPv4 or IPv6 address +// siteId is the site that owns this alias/domain // Automatically adds a corresponding PTR record for non-wildcard domains -func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { +func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error { s.mu.Lock() defer s.mu.Unlock() @@ -69,7 +71,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { } if m[domain] == nil { - m[domain] = &recordSet{} + m[domain] = &recordSet{SiteId: siteId} } rs := m[domain] if isV4 { @@ -179,6 +181,30 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) { delete(s.ptrRecords, ip.String()) } +// GetSiteIdForDomain returns the siteId associated with the given domain. +// It checks exact matches first, then wildcard patterns. +// The second return value is false if the domain is not found in local records. +func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + + domain = strings.ToLower(dns.Fqdn(domain)) + + // Check exact match first + if rs, exists := s.exact[domain]; exists { + return rs.SiteId, true + } + + // Check wildcard matches + for pattern, rs := range s.wildcards { + if matchWildcard(pattern, domain) { + return rs.SiteId, true + } + } + + return 0, false +} + // GetRecords returns all IP addresses for a domain and record type. // The second return value indicates whether the domain exists at all // (true = domain exists, use NODATA if no records; false = NXDOMAIN). diff --git a/dns/dns_records_test.go b/dns/dns_records_test.go index 963dcc1..0b4481d 100644 --- a/dns/dns_records_test.go +++ b/dns/dns_records_test.go @@ -170,14 +170,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) { // Add wildcard records wildcardIP := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", wildcardIP) + err := store.AddRecord("*.autoco.internal", wildcardIP, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } // Add exact record exactIP := net.ParseIP("10.0.0.2") - err = store.AddRecord("exact.autoco.internal", exactIP) + err = store.AddRecord("exact.autoco.internal", exactIP, 0) if err != nil { t.Fatalf("Failed to add exact record: %v", err) } @@ -221,7 +221,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) { // Add complex wildcard pattern ip1 := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.host-0?.autoco.internal", ip1) + err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -262,7 +262,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) { // Add wildcard record ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -297,18 +297,18 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) { ip2 := net.ParseIP("10.0.0.2") ip3 := net.ParseIP("10.0.0.3") - err := store.AddRecord("*.prod.autoco.internal", ip1) + err := store.AddRecord("*.prod.autoco.internal", ip1, 0) if err != nil { t.Fatalf("Failed to add first wildcard: %v", err) } - err = store.AddRecord("*.dev.autoco.internal", ip2) + err = store.AddRecord("*.dev.autoco.internal", ip2, 0) if err != nil { t.Fatalf("Failed to add second wildcard: %v", err) } // Add a broader wildcard that matches both - err = store.AddRecord("*.autoco.internal", ip3) + err = store.AddRecord("*.autoco.internal", ip3, 0) if err != nil { t.Fatalf("Failed to add third wildcard: %v", err) } @@ -337,7 +337,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) { // Add IPv6 wildcard record ip := net.ParseIP("2001:db8::1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add IPv6 wildcard record: %v", err) } @@ -357,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) { // Add wildcard record ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("*.autoco.internal", ip) + err := store.AddRecord("*.autoco.internal", ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -378,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Add record with mixed case ip := net.ParseIP("10.0.0.1") - err := store.AddRecord("MyHost.AutoCo.Internal", ip) + err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0) if err != nil { t.Fatalf("Failed to add mixed case record: %v", err) } @@ -403,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) { // Test wildcard with mixed case wildcardIP := net.ParseIP("10.0.0.2") - err = store.AddRecord("*.Example.Com", wildcardIP) + err = store.AddRecord("*.Example.Com", wildcardIP, 0) if err != nil { t.Fatalf("Failed to add mixed case wildcard: %v", err) } @@ -689,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) { store.AddPTRRecord(ip2, "host2.example.com.") // Add some A records too - store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) + store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0) // Verify PTR records exist if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { @@ -719,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) { // Add an A record - should automatically add PTR record domain := "host.example.com." ip := net.ParseIP("192.168.1.100") - err := store.AddRecord(domain, ip) + err := store.AddRecord(domain, ip, 0) if err != nil { t.Fatalf("Failed to add A record: %v", err) } @@ -737,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) { // Add AAAA record - should also automatically add PTR record domain6 := "ipv6host.example.com." ip6 := net.ParseIP("2001:db8::1") - err = store.AddRecord(domain6, ip6) + err = store.AddRecord(domain6, ip6, 0) if err != nil { t.Fatalf("Failed to add AAAA record: %v", err) } @@ -759,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) { // Add an A record (with automatic PTR) domain := "host.example.com." ip := net.ParseIP("192.168.1.100") - store.AddRecord(domain, ip) + store.AddRecord(domain, ip, 0) // Verify PTR exists reverseDomain := "100.1.168.192.in-addr.arpa." @@ -789,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) { domain := "host.example.com." ip1 := net.ParseIP("192.168.1.100") ip2 := net.ParseIP("192.168.1.101") - store.AddRecord(domain, ip1) - store.AddRecord(domain, ip2) + store.AddRecord(domain, ip1, 0) + store.AddRecord(domain, ip2, 0) // Verify both PTR records exist reverseDomain1 := "100.1.168.192.in-addr.arpa." @@ -820,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) { // Add wildcard record - should NOT create PTR record domain := "*.example.com." ip := net.ParseIP("192.168.1.100") - err := store.AddRecord(domain, ip) + err := store.AddRecord(domain, ip, 0) if err != nil { t.Fatalf("Failed to add wildcard record: %v", err) } @@ -844,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) { // Add first domain with IP domain1 := "host1.example.com." ip := net.ParseIP("192.168.1.100") - store.AddRecord(domain1, ip) + store.AddRecord(domain1, ip, 0) // Verify PTR points to first domain reverseDomain := "100.1.168.192.in-addr.arpa." @@ -858,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) { // Add second domain with same IP - should overwrite PTR domain2 := "host2.example.com." - store.AddRecord(domain2, ip) + store.AddRecord(domain2, ip, 0) // Verify PTR now points to second domain (last one added) result, ok = store.GetPTRRecord(reverseDomain) diff --git a/peers/manager.go b/peers/manager.go index 0566775..e9925eb 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -144,7 +144,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { if address == nil { continue } - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) } monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] @@ -433,7 +433,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { if address == nil { continue } - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId) } pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) @@ -713,7 +713,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error { address := net.ParseIP(alias.AliasAddress) if address != nil { - pm.dnsProxy.AddDNSRecord(alias.Alias, address) + pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId) } // Add an allowed IP for the alias