diff --git a/management/server/types/account.go b/management/server/types/account.go index 9ac2568a0..ca075b9f6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -302,7 +302,11 @@ func (a *Account) GetPeerNetworkMap( var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) } dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) @@ -1651,3 +1655,24 @@ func peerSupportsPortRanges(peerVer string) bool { meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer) return err == nil && meetMinVer } + +// filterZoneRecordsForPeers filters DNS records to only include peers to connect. +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) + peerIPs := make(map[string]struct{}) + + // Add peer's own IP to include its own DNS records + peerIPs[peer.IP.String()] = struct{}{} + + for _, peerToConnect := range peersToConnect { + peerIPs[peerToConnect.IP.String()] = struct{}{} + } + + for _, record := range customZone.Records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f8ab1d627..cd221b590 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -2,14 +2,17 @@ package types import ( "context" + "fmt" "net" "net/netip" "slices" "testing" + "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -835,3 +838,109 @@ func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") } + +func Test_FilterZoneRecordsForPeers(t *testing.T) { + tests := []struct { + name string + peer *nbpeer.Peer + customZone nbdns.CustomZone + peersToConnect []*nbpeer.Peer + expectedRecords []nbdns.SimpleRecord + }{ + { + name: "empty peers to connect", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + { + name: "multiple peers multiple records match", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for i := 1; i <= 100; i++ { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + peersToConnect: func() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + peers = append(peers, &nbpeer.Peer{ + ID: fmt.Sprintf("peer%d", i), + IP: net.ParseIP(fmt.Sprintf("10.0.%d.%d", i/256, i%256)), + }) + } + return peers + }(), + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: func() []nbdns.SimpleRecord { + var records []nbdns.SimpleRecord + for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { + records = append(records, nbdns.SimpleRecord{ + Name: fmt.Sprintf("peer%d.netbird.cloud", i), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: fmt.Sprintf("10.0.%d.%d", i/256, i%256), + }) + } + return records + }(), + }, + { + name: "peers with multiple DNS labels", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer3.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "peer3-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.3"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, + {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer1-backup.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "peer2-service.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + assert.Equal(t, len(tt.expectedRecords), len(result)) + assert.ElementsMatch(t, tt.expectedRecords, result) + }) + } +}