Store site id

This commit is contained in:
Owen
2026-03-06 16:19:00 -08:00
parent f2d0e6a14c
commit e2690bcc03
5 changed files with 59 additions and 33 deletions

View File

@@ -721,8 +721,8 @@ func (p *DNSProxy) runPacketSender() {
// AddDNSRecord adds a DNS record to the local store // AddDNSRecord adds a DNS record to the local store
// domain should be a domain name (e.g., "example.com" or "example.com.") // domain should be a domain name (e.g., "example.com" or "example.com.")
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP) error { func (p *DNSProxy) AddDNSRecord(domain string, ip net.IP, siteId int) error {
return p.recordStore.AddRecord(domain, ip) return p.recordStore.AddRecord(domain, ip, siteId)
} }
// RemoveDNSRecord removes a DNS record from the local store // RemoveDNSRecord removes a DNS record from the local store

View File

@@ -14,7 +14,7 @@ func TestCheckLocalRecordsNODATAForAAAA(t *testing.T) {
// Add an A record for a domain (no AAAA record) // Add an A record for a domain (no AAAA record)
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("myservice.internal", ip) err := proxy.recordStore.AddRecord("myservice.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -64,7 +64,7 @@ func TestCheckLocalRecordsNODATAForA(t *testing.T) {
// Add an AAAA record for a domain (no A record) // Add an AAAA record for a domain (no A record)
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := proxy.recordStore.AddRecord("ipv6only.internal", ip) err := proxy.recordStore.AddRecord("ipv6only.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -113,7 +113,7 @@ func TestCheckLocalRecordsNonExistentDomain(t *testing.T) {
} }
// Add a record so the store isn't empty // Add a record so the store isn't empty
err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1")) err := proxy.recordStore.AddRecord("exists.internal", net.ParseIP("10.0.0.1"), 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add record: %v", err) t.Fatalf("Failed to add record: %v", err)
} }
@@ -144,7 +144,7 @@ func TestCheckLocalRecordsNODATAWildcard(t *testing.T) {
// Add a wildcard A record (no AAAA) // Add a wildcard A record (no AAAA)
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := proxy.recordStore.AddRecord("*.wildcard.internal", ip) err := proxy.recordStore.AddRecord("*.wildcard.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard A record: %v", err) t.Fatalf("Failed to add wildcard A record: %v", err)
} }

View File

@@ -22,6 +22,7 @@ const (
type recordSet struct { type recordSet struct {
A []net.IP A []net.IP
AAAA []net.IP AAAA []net.IP
SiteId int
} }
// DNSRecordStore manages local DNS records for A, AAAA, and PTR queries. // DNSRecordStore manages local DNS records for A, AAAA, and PTR queries.
@@ -46,8 +47,9 @@ func NewDNSRecordStore() *DNSRecordStore {
// domain should be in FQDN format (e.g., "example.com.") // domain should be in FQDN format (e.g., "example.com.")
// domain can contain wildcards: * (0+ chars) and ? (exactly 1 char) // domain can contain wildcards: * (0+ chars) and ? (exactly 1 char)
// ip should be a valid IPv4 or IPv6 address // ip should be a valid IPv4 or IPv6 address
// siteId is the site that owns this alias/domain
// Automatically adds a corresponding PTR record for non-wildcard domains // Automatically adds a corresponding PTR record for non-wildcard domains
func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error { func (s *DNSRecordStore) AddRecord(domain string, ip net.IP, siteId int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -69,7 +71,7 @@ func (s *DNSRecordStore) AddRecord(domain string, ip net.IP) error {
} }
if m[domain] == nil { if m[domain] == nil {
m[domain] = &recordSet{} m[domain] = &recordSet{SiteId: siteId}
} }
rs := m[domain] rs := m[domain]
if isV4 { if isV4 {
@@ -179,6 +181,30 @@ func (s *DNSRecordStore) RemovePTRRecord(ip net.IP) {
delete(s.ptrRecords, ip.String()) delete(s.ptrRecords, ip.String())
} }
// GetSiteIdForDomain returns the siteId associated with the given domain.
// It checks exact matches first, then wildcard patterns.
// The second return value is false if the domain is not found in local records.
func (s *DNSRecordStore) GetSiteIdForDomain(domain string) (int, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
domain = strings.ToLower(dns.Fqdn(domain))
// Check exact match first
if rs, exists := s.exact[domain]; exists {
return rs.SiteId, true
}
// Check wildcard matches
for pattern, rs := range s.wildcards {
if matchWildcard(pattern, domain) {
return rs.SiteId, true
}
}
return 0, false
}
// GetRecords returns all IP addresses for a domain and record type. // GetRecords returns all IP addresses for a domain and record type.
// The second return value indicates whether the domain exists at all // The second return value indicates whether the domain exists at all
// (true = domain exists, use NODATA if no records; false = NXDOMAIN). // (true = domain exists, use NODATA if no records; false = NXDOMAIN).

View File

@@ -170,14 +170,14 @@ func TestDNSRecordStoreWildcard(t *testing.T) {
// Add wildcard records // Add wildcard records
wildcardIP := net.ParseIP("10.0.0.1") wildcardIP := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", wildcardIP) err := store.AddRecord("*.autoco.internal", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
// Add exact record // Add exact record
exactIP := net.ParseIP("10.0.0.2") exactIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("exact.autoco.internal", exactIP) err = store.AddRecord("exact.autoco.internal", exactIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add exact record: %v", err) t.Fatalf("Failed to add exact record: %v", err)
} }
@@ -221,7 +221,7 @@ func TestDNSRecordStoreComplexWildcard(t *testing.T) {
// Add complex wildcard pattern // Add complex wildcard pattern
ip1 := net.ParseIP("10.0.0.1") ip1 := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.host-0?.autoco.internal", ip1) err := store.AddRecord("*.host-0?.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -262,7 +262,7 @@ func TestDNSRecordStoreRemoveWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -297,18 +297,18 @@ func TestDNSRecordStoreMultipleWildcards(t *testing.T) {
ip2 := net.ParseIP("10.0.0.2") ip2 := net.ParseIP("10.0.0.2")
ip3 := net.ParseIP("10.0.0.3") ip3 := net.ParseIP("10.0.0.3")
err := store.AddRecord("*.prod.autoco.internal", ip1) err := store.AddRecord("*.prod.autoco.internal", ip1, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add first wildcard: %v", err) t.Fatalf("Failed to add first wildcard: %v", err)
} }
err = store.AddRecord("*.dev.autoco.internal", ip2) err = store.AddRecord("*.dev.autoco.internal", ip2, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add second wildcard: %v", err) t.Fatalf("Failed to add second wildcard: %v", err)
} }
// Add a broader wildcard that matches both // Add a broader wildcard that matches both
err = store.AddRecord("*.autoco.internal", ip3) err = store.AddRecord("*.autoco.internal", ip3, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add third wildcard: %v", err) t.Fatalf("Failed to add third wildcard: %v", err)
} }
@@ -337,7 +337,7 @@ func TestDNSRecordStoreIPv6Wildcard(t *testing.T) {
// Add IPv6 wildcard record // Add IPv6 wildcard record
ip := net.ParseIP("2001:db8::1") ip := net.ParseIP("2001:db8::1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add IPv6 wildcard record: %v", err) t.Fatalf("Failed to add IPv6 wildcard record: %v", err)
} }
@@ -357,7 +357,7 @@ func TestHasRecordWildcard(t *testing.T) {
// Add wildcard record // Add wildcard record
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("*.autoco.internal", ip) err := store.AddRecord("*.autoco.internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -378,7 +378,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Add record with mixed case // Add record with mixed case
ip := net.ParseIP("10.0.0.1") ip := net.ParseIP("10.0.0.1")
err := store.AddRecord("MyHost.AutoCo.Internal", ip) err := store.AddRecord("MyHost.AutoCo.Internal", ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case record: %v", err) t.Fatalf("Failed to add mixed case record: %v", err)
} }
@@ -403,7 +403,7 @@ func TestDNSRecordStoreCaseInsensitive(t *testing.T) {
// Test wildcard with mixed case // Test wildcard with mixed case
wildcardIP := net.ParseIP("10.0.0.2") wildcardIP := net.ParseIP("10.0.0.2")
err = store.AddRecord("*.Example.Com", wildcardIP) err = store.AddRecord("*.Example.Com", wildcardIP, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add mixed case wildcard: %v", err) t.Fatalf("Failed to add mixed case wildcard: %v", err)
} }
@@ -689,7 +689,7 @@ func TestClearPTRRecords(t *testing.T) {
store.AddPTRRecord(ip2, "host2.example.com.") store.AddPTRRecord(ip2, "host2.example.com.")
// Add some A records too // Add some A records too
store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1")) store.AddRecord("test.example.com.", net.ParseIP("10.0.0.1"), 0)
// Verify PTR records exist // Verify PTR records exist
if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") { if !store.HasPTRRecord("1.1.168.192.in-addr.arpa.") {
@@ -719,7 +719,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add an A record - should automatically add PTR record // Add an A record - should automatically add PTR record
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add A record: %v", err) t.Fatalf("Failed to add A record: %v", err)
} }
@@ -737,7 +737,7 @@ func TestAutomaticPTRRecordOnAdd(t *testing.T) {
// Add AAAA record - should also automatically add PTR record // Add AAAA record - should also automatically add PTR record
domain6 := "ipv6host.example.com." domain6 := "ipv6host.example.com."
ip6 := net.ParseIP("2001:db8::1") ip6 := net.ParseIP("2001:db8::1")
err = store.AddRecord(domain6, ip6) err = store.AddRecord(domain6, ip6, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add AAAA record: %v", err) t.Fatalf("Failed to add AAAA record: %v", err)
} }
@@ -759,7 +759,7 @@ func TestAutomaticPTRRecordOnRemove(t *testing.T) {
// Add an A record (with automatic PTR) // Add an A record (with automatic PTR)
domain := "host.example.com." domain := "host.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain, ip) store.AddRecord(domain, ip, 0)
// Verify PTR exists // Verify PTR exists
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -789,8 +789,8 @@ func TestAutomaticPTRRecordOnRemoveAll(t *testing.T) {
domain := "host.example.com." domain := "host.example.com."
ip1 := net.ParseIP("192.168.1.100") ip1 := net.ParseIP("192.168.1.100")
ip2 := net.ParseIP("192.168.1.101") ip2 := net.ParseIP("192.168.1.101")
store.AddRecord(domain, ip1) store.AddRecord(domain, ip1, 0)
store.AddRecord(domain, ip2) store.AddRecord(domain, ip2, 0)
// Verify both PTR records exist // Verify both PTR records exist
reverseDomain1 := "100.1.168.192.in-addr.arpa." reverseDomain1 := "100.1.168.192.in-addr.arpa."
@@ -820,7 +820,7 @@ func TestNoPTRForWildcardRecords(t *testing.T) {
// Add wildcard record - should NOT create PTR record // Add wildcard record - should NOT create PTR record
domain := "*.example.com." domain := "*.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
err := store.AddRecord(domain, ip) err := store.AddRecord(domain, ip, 0)
if err != nil { if err != nil {
t.Fatalf("Failed to add wildcard record: %v", err) t.Fatalf("Failed to add wildcard record: %v", err)
} }
@@ -844,7 +844,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add first domain with IP // Add first domain with IP
domain1 := "host1.example.com." domain1 := "host1.example.com."
ip := net.ParseIP("192.168.1.100") ip := net.ParseIP("192.168.1.100")
store.AddRecord(domain1, ip) store.AddRecord(domain1, ip, 0)
// Verify PTR points to first domain // Verify PTR points to first domain
reverseDomain := "100.1.168.192.in-addr.arpa." reverseDomain := "100.1.168.192.in-addr.arpa."
@@ -858,7 +858,7 @@ func TestPTRRecordOverwrite(t *testing.T) {
// Add second domain with same IP - should overwrite PTR // Add second domain with same IP - should overwrite PTR
domain2 := "host2.example.com." domain2 := "host2.example.com."
store.AddRecord(domain2, ip) store.AddRecord(domain2, ip, 0)
// Verify PTR now points to second domain (last one added) // Verify PTR now points to second domain (last one added)
result, ok = store.GetPTRRecord(reverseDomain) result, ok = store.GetPTRRecord(reverseDomain)

View File

@@ -144,7 +144,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
if address == nil { if address == nil {
continue continue
} }
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
} }
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
@@ -433,7 +433,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
if address == nil { if address == nil {
continue continue
} }
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteConfig.SiteId)
} }
pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint) pm.peerMonitor.UpdateHolepunchEndpoint(siteConfig.SiteId, siteConfig.Endpoint)
@@ -713,7 +713,7 @@ func (pm *PeerManager) AddAlias(siteId int, alias Alias) error {
address := net.ParseIP(alias.AliasAddress) address := net.ParseIP(alias.AliasAddress)
if address != nil { if address != nil {
pm.dnsProxy.AddDNSRecord(alias.Alias, address) pm.dnsProxy.AddDNSRecord(alias.Alias, address, siteId)
} }
// Add an allowed IP for the alias // Add an allowed IP for the alias