diff --git a/management/server/types/account.go b/management/server/types/account.go index 9e86d8936..692dc4541 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -355,6 +355,119 @@ func (a *Account) GetPeerNetworkMap( return nm } +// GetPeerNetworkMap returns the networkmap for the given peer ID. +func (a *Account) GetPeerNetworkMapCompacted( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + peer := a.Peers[peerID] + if peer == nil { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) + // exclude expired peers + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) + isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) + } + peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + if peersCustomZone.Domain != "" { + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + type crt struct { + route *route.Route + peerIds []string + } + var routes []*route.Route + rtfilter := make(map[string]crt) + otherRoutesIDs := slices.Concat(networkResourcesRoutes, routesUpdate) + for _, route := range otherRoutesIDs { + rid, pid := splitRouteAndPeer(route) + if pid == peerID || len(pid) == 0 { + routes = append(routes, route) + continue + } + crt := rtfilter[rid] + crt.peerIds = append(crt.peerIds, pid) + crt.route = route.CopyClean() + rtfilter[rid] = crt + } + + for rid, crt := range rtfilter { + crt.route.ApplicablePeerIDs = crt.peerIds + crt.route.ID = route.ID(rid) + routes = append(routes, crt.route) + } + + nm := &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: a.Network.Copy(), + Routes: routes, + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + } + + if metrics != nil { + objectCount := int64(len(peersToConnectIncludingRouters) + len(expiredPeers) + len(routesUpdate) + len(networkResourcesRoutes) + len(firewallRules) + +len(networkResourcesFirewallRules) + len(routesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ + "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d, network resources routes: %d, network resources firewall rules: %d, routes firewall rules: %d", + a.Id, objectCount, len(peersToConnectIncludingRouters), len(expiredPeers), len(routesUpdate), len(firewallRules), len(networkResourcesRoutes), len(networkResourcesFirewallRules), len(routesFirewallRules)) + } + } + + return nm +} + func (a *Account) addNetworksRoutingPeers( networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go new file mode 100644 index 000000000..792122b11 --- /dev/null +++ b/management/server/types/account_components.go @@ -0,0 +1,442 @@ +package types + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/route" +) + +func (a *Account) GetPeerNetworkMapComponents( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, +) *NetworkMapComponents { + + peer := a.Peers[peerID] + if peer == nil { + return nil + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return nil + } + + components := &NetworkMapComponents{ + PeerID: peerID, + Serial: a.Network.Serial, + Network: a.Network.Copy(), + Peers: make(map[string]*nbpeer.Peer), + Groups: make(map[string]*Group), + Policies: make([]*Policy, 0), + Routes: make([]*route.Route, 0), + NameServerGroups: make([]*nbdns.NameServerGroup, 0), + CustomZoneDomain: peersCustomZone.Domain, + AllDNSRecords: peersCustomZone.Records, + ResourcePoliciesMap: make(map[string][]*Policy), + RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), + NetworkResources: make([]*resourceTypes.NetworkResource, 0), + } + + components.AccountSettings = &AccountSettingsInfo{ + PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: a.Settings.PeerLoginExpiration, + PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: a.Settings.PeerInactivityExpiration, + } + + components.DNSSettings = &a.DNSSettings + + relevantPeerIDsList, relevantGroupIDs := a.findRelevantPeersAndGroups(ctx, peerID, validatedPeersMap) + + relevantPeerIDsMap := make(map[string]struct{}) + for _, pid := range relevantPeerIDsList { + relevantPeerIDsMap[pid] = struct{}{} + } + + _, _, networkResourcesSourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) + for sourcePeerID := range networkResourcesSourcePeers { + relevantPeerIDsMap[sourcePeerID] = struct{}{} + } + + for pid := range relevantPeerIDsMap { + if p := a.Peers[pid]; p != nil { + components.Peers[pid] = p + } + } + + for gid := range relevantGroupIDs { + if g := a.Groups[gid]; g != nil { + components.Groups[gid] = g + } + } + + for _, policy := range a.Policies { + if a.isPolicyRelevantForPeer(ctx, policy, peerID, relevantGroupIDs) { + components.Policies = append(components.Policies, policy) + } + } + + for _, r := range a.Routes { + if a.isRouteRelevantForPeer(ctx, r, peerID, relevantGroupIDs) { + components.Routes = append(components.Routes, r) + } + } + + for _, nsGroup := range a.NameServerGroups { + if nsGroup.Enabled { + for _, gID := range nsGroup.Groups { + if _, found := relevantGroupIDs[gID]; found { + components.NameServerGroups = append(components.NameServerGroups, nsGroup.Copy()) + break + } + } + } + } + + relevantResourceIDs := make(map[string]struct{}) + relevantNetworkIDs := make(map[string]struct{}) + + for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + + policies, exists := resourcePolicies[resource.ID] + if !exists { + continue + } + + isRelevant := false + + networkRoutingPeers, routerExists := routers[resource.NetworkID] + if routerExists { + if _, ok := networkRoutingPeers[peerID]; ok { + isRelevant = true + } + } + + if !isRelevant { + for _, policy := range policies { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } + + for _, p := range peers { + if p == peerID && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + isRelevant = true + break + } + } + + if isRelevant { + break + } + } + } + + if isRelevant { + relevantResourceIDs[resource.ID] = struct{}{} + relevantNetworkIDs[resource.NetworkID] = struct{}{} + components.NetworkResources = append(components.NetworkResources, resource) + } + } + + for resID, policies := range resourcePolicies { + if _, isRelevant := relevantResourceIDs[resID]; !isRelevant { + continue + } + + for _, p := range policies { + for _, rule := range p.Rules { + for _, srcGroupID := range rule.Sources { + if g := a.Groups[srcGroupID]; g != nil { + if _, exists := components.Groups[srcGroupID]; !exists { + components.Groups[srcGroupID] = g + } + } + } + for _, dstGroupID := range rule.Destinations { + if g := a.Groups[dstGroupID]; g != nil { + if _, exists := components.Groups[dstGroupID]; !exists { + components.Groups[dstGroupID] = g + } + } + } + } + } + components.ResourcePoliciesMap[resID] = policies + } + + for networkID, networkRouters := range routers { + if _, isRelevant := relevantNetworkIDs[networkID]; !isRelevant { + continue + } + + components.RoutersMap[networkID] = networkRouters + for peerIDKey := range networkRouters { + if _, exists := components.Peers[peerIDKey]; !exists { + if p := a.Peers[peerIDKey]; p != nil { + components.Peers[peerIDKey] = p + } + } + } + } + + for groupID, groupInfo := range components.Groups { + filteredPeers := make([]string, 0, len(groupInfo.Peers)) + for _, peerID := range groupInfo.Peers { + if _, exists := components.Peers[peerID]; exists { + filteredPeers = append(filteredPeers, peerID) + } + } + + if len(filteredPeers) == 0 { + delete(components.Groups, groupID) + } else { + groupInfo.Peers = filteredPeers + components.Groups[groupID] = groupInfo + } + } + + return components +} + +func (a *Account) findRelevantPeersAndGroups( + ctx context.Context, + peerID string, + validatedPeersMap map[string]struct{}, +) ([]string, map[string]struct{}) { + relevantPeerIDs := make(map[string]struct{}) + relevantGroupIDs := make(map[string]struct{}) + + relevantPeerIDs[peerID] = struct{}{} + + for groupID, group := range a.Groups { + for _, pid := range group.Peers { + if pid == peerID { + relevantGroupIDs[groupID] = struct{}{} + break + } + } + } + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + var sourcePeers, destinationPeers []string + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers = []string{rule.SourceResource.ID} + if rule.SourceResource.ID == peerID { + peerInSources = true + } + } else { + sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers = []string{rule.DestinationResource.ID} + if rule.DestinationResource.ID == peerID { + peerInDestinations = true + } + } else { + destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) + } + + if rule.Bidirectional { + if peerInSources { + for _, pid := range destinationPeers { + relevantPeerIDs[pid] = struct{}{} + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = struct{}{} + } + } + if peerInDestinations { + for _, pid := range sourcePeers { + relevantPeerIDs[pid] = struct{}{} + } + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = struct{}{} + } + } + } + + if peerInSources { + for _, pid := range destinationPeers { + relevantPeerIDs[pid] = struct{}{} + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = struct{}{} + } + } + + if peerInDestinations { + for _, pid := range sourcePeers { + relevantPeerIDs[pid] = struct{}{} + } + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = struct{}{} + } + } + } + } + + for _, r := range a.Routes { + isRelevant := false + + for _, groupID := range r.Groups { + if _, found := relevantGroupIDs[groupID]; found { + isRelevant = true + break + } + } + + if r.Peer == peerID || r.PeerID == peerID { + isRelevant = true + } + + for _, groupID := range r.PeerGroups { + if group := a.Groups[groupID]; group != nil { + for _, pid := range group.Peers { + if pid == peerID { + isRelevant = true + break + } + } + } + } + + if isRelevant { + for _, groupID := range r.Groups { + relevantGroupIDs[groupID] = struct{}{} + } + for _, groupID := range r.PeerGroups { + relevantGroupIDs[groupID] = struct{}{} + } + for _, groupID := range r.AccessControlGroups { + relevantGroupIDs[groupID] = struct{}{} + } + + if r.Peer != "" { + relevantPeerIDs[r.Peer] = struct{}{} + } + if r.PeerID != "" { + relevantPeerIDs[r.PeerID] = struct{}{} + } + } + } + + peerIDsList := make([]string, 0, len(relevantPeerIDs)) + for pid := range relevantPeerIDs { + peerIDsList = append(peerIDsList, pid) + } + + return peerIDsList, relevantGroupIDs +} + +func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]string, bool) { + peerInGroups := false + uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeerIDs := make([]string, 0, len(uniquePeerIDs)) + + for _, p := range uniquePeerIDs { + peer, ok := a.Peers[p] + if !ok || peer == nil { + continue + } + + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + + return filteredPeerIDs, peerInGroups +} + +func (a *Account) isPolicyRelevantForPeer(ctx context.Context, policy *Policy, peerID string, relevantGroupIDs map[string]struct{}) bool { + for _, rule := range policy.Rules { + for _, srcGroupID := range rule.Sources { + if _, found := relevantGroupIDs[srcGroupID]; found { + return true + } + } + + for _, dstGroupID := range rule.Destinations { + if _, found := relevantGroupIDs[dstGroupID]; found { + return true + } + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID == peerID { + return true + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID == peerID { + return true + } + } + + return false +} + +func (a *Account) isRouteRelevantForPeer(ctx context.Context, r *route.Route, peerID string, relevantGroupIDs map[string]struct{}) bool { + if r.Peer == peerID || r.PeerID == peerID { + return true + } + + for _, groupID := range r.Groups { + if _, found := relevantGroupIDs[groupID]; found { + return true + } + } + + for _, groupID := range r.PeerGroups { + if group := a.Groups[groupID]; group != nil { + for _, pid := range group.Peers { + if pid == peerID { + return true + } + } + } + } + + for _, groupID := range r.AccessControlGroups { + if _, found := relevantGroupIDs[groupID]; found { + return true + } + } + + return false +} + diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go new file mode 100644 index 000000000..383fd1c90 --- /dev/null +++ b/management/server/types/networkmap_comparison_test.go @@ -0,0 +1,769 @@ +package types + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" +) + +func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + nil, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + + if components == nil { + t.Fatal("GetPeerNetworkMapComponents returned nil") + } + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + + if newNetworkMap == nil { + t.Fatal("CalculateNetworkMapFromComponents returned nil") + } + + compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) +} + +func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + nil, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") + + normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + componentsJSON, err := json.MarshalIndent(components, "", " ") + require.NoError(t, err, "error marshaling components to JSON") + + legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") + + newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") + + goldenDir := filepath.Join("testdata", "comparison") + err = os.MkdirAll(goldenDir, 0755) + require.NoError(t, err) + + legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") + err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) + require.NoError(t, err, "error writing legacy golden file") + + newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") + err = os.WriteFile(newGoldenPath, newJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + componentsPath := filepath.Join(goldenDir, "components.json") + err = os.WriteFile(componentsPath, componentsJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + require.JSONEq(t, string(legacyJSON), string(newJSON), + "NetworkMaps from legacy and components approaches do not match.\n"+ + "Legacy JSON saved to: %s\n"+ + "Components JSON saved to: %s", + legacyGoldenPath, newGoldenPath) + + t.Logf("✅ NetworkMaps are identical") + t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) + t.Logf(" Components NetworkMap: %s", newGoldenPath) +} + +func normalizeAndSortNetworkMap(nm *NetworkMap) { + if nm == nil { + return + } + + sort.Slice(nm.Peers, func(i, j int) bool { + return nm.Peers[i].ID < nm.Peers[j].ID + }) + + sort.Slice(nm.OfflinePeers, func(i, j int) bool { + return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID + }) + + sort.Slice(nm.Routes, func(i, j int) bool { + return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) + }) + + sort.Slice(nm.FirewallRules, func(i, j int) bool { + if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { + return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP + } + if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { + return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction + } + return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol + }) + + for i := range nm.RoutesFirewallRules { + sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) + } + + sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { + if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { + return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination + } + + minLen := len(nm.RoutesFirewallRules[i].SourceRanges) + if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { + minLen = len(nm.RoutesFirewallRules[j].SourceRanges) + } + for k := 0; k < minLen; k++ { + if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { + return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] + } + } + if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { + return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) + } + + return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) + }) + + if nm.DNSConfig.CustomZones != nil { + for i := range nm.DNSConfig.CustomZones { + sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { + return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name + }) + } + } +} + +func compareNetworkMaps(t *testing.T, legacy, new *NetworkMap) { + t.Helper() + + if legacy.Network.Serial != new.Network.Serial { + t.Errorf("Network Serial mismatch: legacy=%d, new=%d", legacy.Network.Serial, new.Network.Serial) + } + + if len(legacy.Peers) != len(new.Peers) { + t.Errorf("Peers count mismatch: legacy=%d, new=%d", len(legacy.Peers), len(new.Peers)) + } + + legacyPeerIDs := make(map[string]bool) + for _, p := range legacy.Peers { + legacyPeerIDs[p.ID] = true + } + + for _, p := range new.Peers { + if !legacyPeerIDs[p.ID] { + t.Errorf("New NetworkMap contains peer %s not in legacy", p.ID) + } + } + + if len(legacy.OfflinePeers) != len(new.OfflinePeers) { + t.Errorf("OfflinePeers count mismatch: legacy=%d, new=%d", len(legacy.OfflinePeers), len(new.OfflinePeers)) + } + + if len(legacy.FirewallRules) != len(new.FirewallRules) { + t.Logf("FirewallRules count mismatch: legacy=%d, new=%d", len(legacy.FirewallRules), len(new.FirewallRules)) + } + + if len(legacy.Routes) != len(new.Routes) { + t.Logf("Routes count mismatch: legacy=%d, new=%d", len(legacy.Routes), len(new.Routes)) + } + + if len(legacy.RoutesFirewallRules) != len(new.RoutesFirewallRules) { + t.Logf("RoutesFirewallRules count mismatch: legacy=%d, new=%d", len(legacy.RoutesFirewallRules), len(new.RoutesFirewallRules)) + } + + if legacy.DNSConfig.ServiceEnable != new.DNSConfig.ServiceEnable { + t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, new=%v", legacy.DNSConfig.ServiceEnable, new.DNSConfig.ServiceEnable) + } +} + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" + expiredPeerID = "peer-98" + offlinePeerID = "peer-99" + routingPeerID = "peer-95" + testAccountID = "account-comparison-test" +) + +func createTestAccount() *Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + } + + groups := map[string]*Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, + Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &Network{ + Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} + +func BenchmarkLegacyNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + nil, + ) + } +} + +func BenchmarkComponentsNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} + +func BenchmarkComponentsCreation(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + } +} + +func BenchmarkCalculationFromComponents(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} + +func TestGetPeerNetworkMap_ProdAccount_CompareImplementations(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + + testAccount := loadProdAccountFromJSON(t) + + testingPeerID := "cq3526bl0ubs73bbtpbg" + require.Contains(t, testAccount.Peers, testingPeerID, "Testing peer should exist in account") + + validatedPeersMap := make(map[string]struct{}) + for peerID := range testAccount.Peers { + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := testAccount.GetResourcePoliciesMap() + routers := testAccount.GetResourceRoutersMap() + + legacyNetworkMap := testAccount.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + require.NotNil(t, legacyNetworkMap, "GetPeerNetworkMap returned nil") + + components := testAccount.GetPeerNetworkMapComponents(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers) + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") + + normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + componentsJSON, err := json.MarshalIndent(components, "", " ") + require.NoError(t, err, "error marshaling components to JSON") + + legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") + + newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") + + outputDir := filepath.Join("testdata", fmt.Sprintf("compare_peer_%s", testingPeerID)) + err = os.MkdirAll(outputDir, 0755) + require.NoError(t, err) + + legacyFilePath := filepath.Join(outputDir, "legacy_networkmap.json") + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + + componentsPath := filepath.Join(outputDir, "components.json") + err = os.WriteFile(componentsPath, componentsJSON, 0644) + require.NoError(t, err) + + newFilePath := filepath.Join(outputDir, "components_networkmap.json") + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + + t.Logf("Files saved to:\n Legacy NetworkMap: %s\n Components: %s\n Components NetworkMap: %s", + legacyFilePath, componentsPath, newFilePath) + + require.JSONEq(t, string(legacyJSON), string(newJSON), + "NetworkMaps from legacy and components approaches do not match for peer %s.\n"+ + "Legacy JSON saved to: %s\n"+ + "Components JSON saved to: %s\n"+ + "Components NetworkMap saved to: %s", + testingPeerID, legacyFilePath, componentsPath, newFilePath) + + t.Logf("✅ NetworkMaps are identical for peer %s", testingPeerID) +} + +func loadProdAccountFromJSON(t testing.TB) *Account { + t.Helper() + + testDataPath := filepath.Join("testdata", "account_cnlf3j3l0ubs738o5d4g.json") + data, err := os.ReadFile(testDataPath) + require.NoError(t, err, "Failed to read prod account JSON file") + + var account Account + err = json.Unmarshal(data, &account) + require.NoError(t, err, "Failed to unmarshal prod account") + + if account.Groups == nil { + account.Groups = make(map[string]*Group) + } + if account.Peers == nil { + account.Peers = make(map[string]*nbpeer.Peer) + } + if account.Policies == nil { + account.Policies = []*Policy{} + } + + return &account +} + +func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { + account := loadProdAccountFromJSON(b) + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}, len(account.Peers)) + for _, peer := range account.Peers { + validatedPeersMap[peer.ID] = struct{}{} + } + dnsDomain := account.Settings.DNSDomain + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + + builder := NewNetworkMapBuilder(account, validatedPeersMap) + + testingPeerID := "d3knp53l0ubs738a3n6g" + + regularNm := builder.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, nil) + compactNm := builder.GetPeerNetworkMapCompact(ctx, testingPeerID, customZone, validatedPeersMap, nil) + compactCachedNm := builder.GetPeerNetworkMapCompactCached(ctx, testingPeerID, customZone, validatedPeersMap, nil) + + regularJSON, err := json.Marshal(regularNm) + require.NoError(b, err) + + compactJSON, err := json.Marshal(compactNm) + require.NoError(b, err) + + compactCachedJSON, err := json.Marshal(compactCachedNm) + require.NoError(b, err) + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, validatedPeersMap, resourcePolicies, routers) + componentsJSON, err := json.Marshal(components) + require.NoError(b, err) + + regularSize := len(regularJSON) + compactSize := len(compactJSON) + compactCachedSize := len(compactCachedJSON) + componentsSize := len(componentsJSON) + + compactSavingsPercent := 100 - int(float64(compactCachedSize)/float64(regularSize)*100) + componentsSavingsPercent := 100 - int(float64(componentsSize)/float64(regularSize)*100) + + b.ReportMetric(float64(regularSize), "regular_bytes") + b.ReportMetric(float64(compactCachedSize), "compact_cached_bytes") + b.ReportMetric(float64(componentsSize), "components_bytes") + b.ReportMetric(float64(compactSavingsPercent), "compact_savings_%") + b.ReportMetric(float64(componentsSavingsPercent), "components_savings_%") + + b.Logf("========== Network Map Size Comparison ==========") + b.Logf("Regular network map: %d bytes", regularSize) + b.Logf("Compact network map: %d bytes (-%d%%)", compactSize, 100-int(float64(compactSize)/float64(regularSize)*100)) + b.Logf("Compact cached network map: %d bytes (-%d%%)", compactCachedSize, compactSavingsPercent) + b.Logf("Components: %d bytes (-%d%%)", componentsSize, componentsSavingsPercent) + b.Logf("") + b.Logf("Bandwidth savings (Compact cached): %d bytes saved (%d%%)", regularSize-compactCachedSize, compactSavingsPercent) + b.Logf("Bandwidth savings (Components): %d bytes saved (%d%%)", regularSize-componentsSize, componentsSavingsPercent) + b.Logf("=================================================") + + b.Run("Regular", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, nil) + } + }) + + b.Run("CompactOnDemand", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = builder.GetPeerNetworkMapCompact(ctx, testingPeerID, customZone, validatedPeersMap, nil) + } + }) + + b.Run("CompactCached", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = builder.GetPeerNetworkMapCompactCached(ctx, testingPeerID, customZone, validatedPeersMap, nil) + } + }) + b.Run("Legacy", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + }) + b.Run("LegacyCompacted", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapCompacted(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + }) + + b.Run("ComponentsNetworkMap", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + testingPeerID, + customZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + _ = CalculateNetworkMapFromComponents(ctx, components) + } + }) + + b.Run("ComponentsCreation", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapComponents( + ctx, + testingPeerID, + customZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + } + }) + + b.Run("CalculationFromComponents", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculateNetworkMapFromComponents(ctx, components) + } + }) +} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go new file mode 100644 index 000000000..1bb6c0c9d --- /dev/null +++ b/management/server/types/networkmap_components.go @@ -0,0 +1,951 @@ +package types + +import ( + "context" + "net" + "net/netip" + "strings" + "time" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type NetworkMapComponents struct { + PeerID string + Serial uint64 + + Network *Network + AccountSettings *AccountSettingsInfo + DNSSettings *DNSSettings + CustomZoneDomain string + + Peers map[string]*nbpeer.Peer + Groups map[string]*Group + Policies []*Policy + Routes []*route.Route + NameServerGroups []*nbdns.NameServerGroup + AllDNSRecords []nbdns.SimpleRecord + ResourcePoliciesMap map[string][]*Policy + RoutersMap map[string]map[string]*routerTypes.NetworkRouter + NetworkResources []*resourceTypes.NetworkResource +} + +type AccountSettingsInfo struct { + PeerLoginExpirationEnabled bool + PeerLoginExpiration time.Duration + PeerInactivityExpirationEnabled bool + PeerInactivityExpiration time.Duration +} + +func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer { + return c.Peers[peerID] +} + +func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group { + return c.Groups[groupID] +} + +func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool { + group := c.GetGroupInfo(groupID) + if group == nil { + return false + } + + for _, pid := range group.Peers { + if pid == peerID { + return true + } + } + return false +} + +func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} { + groups := make(map[string]struct{}) + for groupID, group := range c.Groups { + for _, pid := range group.Peers { + if pid == peerID { + groups[groupID] = struct{}{} + break + } + } + } + return groups +} + +func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool { + if len(postureCheckIDs) == 0 { + return true + } + _, exists := c.Peers[peerID] + return exists +} + +type NetworkMapCalculator struct { + components *NetworkMapComponents +} + +func NewNetworkMapCalculator(components *NetworkMapComponents) *NetworkMapCalculator { + return &NetworkMapCalculator{ + components: components, + } +} + +func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap { + calculator := NewNetworkMapCalculator(components) + return calculator.Calculate(ctx) +} + +func (calc *NetworkMapCalculator) Calculate(ctx context.Context) *NetworkMap { + targetPeerID := calc.components.PeerID + + aclPeers, firewallRules := calc.getPeerConnectionResources(ctx, targetPeerID) + + peersToConnect, expiredPeers := calc.filterPeersByLoginExpiration(aclPeers) + + routesUpdate := calc.getRoutesToSync(ctx, targetPeerID, peersToConnect) + routesFirewallRules := calc.getPeerRoutesFirewallRules(ctx, targetPeerID) + + isRouter, networkResourcesRoutes, sourcePeers := calc.getNetworkResourcesRoutesToSync(ctx, targetPeerID) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = calc.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes) + } + + peersToConnectIncludingRouters := calc.addNetworksRoutingPeers( + networkResourcesRoutes, + targetPeerID, + peersToConnect, + expiredPeers, + isRouter, + sourcePeers, + ) + + dnsManagementStatus := calc.getPeerDNSManagementStatus(targetPeerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + if calc.components.CustomZoneDomain != "" { + records := calc.filterZoneRecordsForPeers(targetPeerID, peersToConnectIncludingRouters, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: calc.components.CustomZoneDomain, + Records: records, + }) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = calc.getPeerNSGroups(targetPeerID) + } + + return &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: calc.components.Network.Copy(), + Routes: append(networkResourcesRoutes, routesUpdate...), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...), + } +} + +func (calc *NetworkMapCalculator) getPeerConnectionResources(ctx context.Context, targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule) { + targetPeer := calc.components.GetPeerInfo(targetPeerID) + if targetPeer == nil { + return nil, nil + } + + generateResources, getAccumulatedResources := calc.connResourcesGenerator(ctx, targetPeer) + + for _, policy := range calc.components.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = calc.getPeerFromResource(rule.SourceResource, targetPeerID) + } else { + sourcePeers, peerInSources = calc.getAllPeersFromGroups(ctx, rule.Sources, targetPeerID, policy.SourcePostureChecks) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = calc.getPeerFromResource(rule.DestinationResource, targetPeerID) + } else { + destinationPeers, peerInDestinations = calc.getAllPeersFromGroups(ctx, rule.Destinations, targetPeerID, nil) + } + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + } + } + + return getAccumulatedResources() +} + +func (calc *NetworkMapCalculator) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: net.IP(peer.IP).String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + ruleID := rule.ID + fr.PeerIP + string(rune(direction)) + + fr.Protocol + fr.Action + for _, port := range rule.Ports { + ruleID += port + } + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + rules = append(rules, &fr) + continue + } + + rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{ + ID: rule.ID, + Ports: rule.Ports, + PortRanges: rule.PortRanges, + Protocol: rule.Protocol, + Action: rule.Action, + }, targetPeer)...) + } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +func (calc *NetworkMapCalculator) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { + peerInGroups := false + uniquePeerIDs := calc.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + + for _, p := range uniquePeerIDs { + peerInfo := calc.components.GetPeerInfo(p) + if peerInfo == nil { + continue + } + + if _, ok := calc.components.Peers[p]; !ok { + continue + } + + if !calc.components.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) { + continue + } + + if p == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peerInfo) + } + + return filteredPeers, peerInGroups +} + +func (calc *NetworkMapCalculator) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { + peerIDs := make(map[string]struct{}, len(groups)) + for _, groupID := range groups { + group := calc.components.GetGroupInfo(groupID) + if group == nil { + continue + } + + if len(groups) == 1 { + return group.Peers + } + + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } + } + + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) + } + + return ids +} + +func (calc *NetworkMapCalculator) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + peerInfo := calc.components.GetPeerInfo(resource.ID) + if peerInfo == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peerInfo}, resource.ID == peerID +} + +func (calc *NetworkMapCalculator) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, p := range aclPeers { + expired, _ := p.LoginExpired(calc.components.AccountSettings.PeerLoginExpiration) + if calc.components.AccountSettings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + return peersToConnect, expiredPeers +} + +func (calc *NetworkMapCalculator) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := calc.components.GetPeerGroups(peerID) + enabled := true + for _, groupID := range calc.components.DNSSettings.DisabledManagementGroups { + if _, found := peerGroups[groupID]; found { + enabled = false + break + } + } + return enabled +} + +func (calc *NetworkMapCalculator) filterZoneRecordsForPeers(peerID string, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { + filteredRecords := make([]nbdns.SimpleRecord, 0, len(calc.components.AllDNSRecords)) + peerIPs := make(map[string]struct{}) + + targetPeerInfo := calc.components.GetPeerInfo(peerID) + if targetPeerInfo != nil { + peerIPs[string(targetPeerInfo.IP)] = struct{}{} + } + + for _, peer := range peersToConnect { + peerIPs[string(peer.IP)] = struct{}{} + } + + for _, peer := range expiredPeers { + peerIPs[string(peer.IP)] = struct{}{} + } + + for _, record := range calc.components.AllDNSRecords { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} + +func (calc *NetworkMapCalculator) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup { + groupList := calc.components.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range calc.components.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + targetPeerInfo := calc.components.GetPeerInfo(peerID) + if targetPeerInfo != nil && !calc.peerIsNameserver(targetPeerInfo, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (calc *NetworkMapCalculator) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peerInfo.IP.String() == ns.IP.String() { + return true + } + } + return false +} + +func (calc *NetworkMapCalculator) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := calc.getRoutingPeerRoutes(ctx, peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + groupListMap := calc.components.GetPeerGroups(peerID) + for _, peer := range aclPeers { + activeRoutes, _ := calc.getRoutingPeerRoutes(ctx, peer.ID) + groupFilteredRoutes := calc.filterRoutesByGroups(activeRoutes, groupListMap) + filteredRoutes := calc.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +func (calc *NetworkMapCalculator) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peerInfo := calc.components.GetPeerInfo(peerID) + if peerInfo == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + routeObj := calc.copyRoute(r) + routeObj.Peer = peerInfo.Key + + if r.Enabled { + enabledRoutes = append(enabledRoutes, routeObj) + return + } + disabledRoutes = append(disabledRoutes, routeObj) + } + + for _, r := range calc.components.Routes { + for _, groupID := range r.PeerGroups { + group := calc.components.GetGroupInfo(groupID) + if group == nil { + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := calc.copyRoute(r) + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) + takeRoute(newPeerRoute, id) + break + } + } + if r.Peer == peerID || r.PeerID == peerID { + takeRoute(calc.copyRoute(r), peerID) + } + } + + return enabledRoutes, disabledRoutes +} + +func (calc *NetworkMapCalculator) copyRoute(r *route.Route) *route.Route { + var groups, accessControlGroups, peerGroups []string + var domains domain.List + + if r.Groups != nil { + groups = append([]string{}, r.Groups...) + } + if r.AccessControlGroups != nil { + accessControlGroups = append([]string{}, r.AccessControlGroups...) + } + if r.PeerGroups != nil { + peerGroups = append([]string{}, r.PeerGroups...) + } + if r.Domains != nil { + domains = append(domain.List{}, r.Domains...) + } + + return &route.Route{ + ID: r.ID, + AccountID: r.AccountID, + Network: r.Network, + NetworkType: r.NetworkType, + Description: r.Description, + Peer: r.Peer, + PeerID: r.PeerID, + Metric: r.Metric, + Masquerade: r.Masquerade, + NetID: r.NetID, + Enabled: r.Enabled, + Groups: groups, + AccessControlGroups: accessControlGroups, + PeerGroups: peerGroups, + Domains: domains, + KeepRoute: r.KeepRoute, + } +} + +func (calc *NetworkMapCalculator) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +func (calc *NetworkMapCalculator) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +func (calc *NetworkMapCalculator) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := calc.getRoutingPeerRoutes(ctx, peerID) + for _, r := range enabledRoutes { + if len(r.AccessControlGroups) == 0 { + defaultPermit := calc.getDefaultPermit(r) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := calc.getDistributionGroupsPeers(r) + + for _, accessGroup := range r.AccessControlGroups { + policies := calc.getAllRoutePoliciesFromGroups([]string{accessGroup}) + rules := calc.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (calc *NetworkMapCalculator) findRoute(routeID route.ID) *route.Route { + for _, r := range calc.components.Routes { + if r.ID == routeID { + return r + } + } + + parts := strings.Split(string(routeID), ":") + if len(parts) > 1 { + baseRouteID := route.ID(parts[0]) + for _, r := range calc.components.Routes { + if r.ID == baseRouteID { + return r + } + } + } + + return nil +} + +func (calc *NetworkMapCalculator) getDefaultPermit(r *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: r.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: r.Domains, + IsDynamic: len(r.Domains) > 0, + RouteID: r.ID, + } + + rules = append(rules, &rule) + + if len(r.Domains) > 0 { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +func (calc *NetworkMapCalculator) getDistributionGroupsPeers(r *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range r.Groups { + group := calc.components.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (calc *NetworkMapCalculator) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group := calc.components.GetGroupInfo(groupID) + if group == nil { + continue + } + + for _, policy := range calc.components.Policies { + for _, rule := range policy.Rules { + for _, destGroupID := range rule.Destinations { + if destGroupID == group.ID { + routePolicies = append(routePolicies, policy) + break + } + } + } + } + } + + return routePolicies +} + +func (calc *NetworkMapCalculator) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := calc.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) + rules := calc.generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (calc *NetworkMapCalculator) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := calc.components.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := calc.components.Peers[pID] + if distPeer && valid && calc.components.ValidatePostureChecksOnPeer(pID, postureChecks) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := calc.components.Peers[rule.SourceResource.ID] + if distPeer && valid && calc.components.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peerInfo := calc.components.GetPeerInfo(pID) + if peerInfo == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peerInfo) + } + return distributionGroupPeers +} + +func (calc *NetworkMapCalculator) generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, rulePeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { + sourceRanges := make([]string, 0, len(rulePeers)) + for _, peer := range rulePeers { + if peer == nil { + continue + } + sourceRanges = append(sourceRanges, peer.IP.String()+"/32") + } + + if len(sourceRanges) == 0 { + return nil + } + + baseRule := &RouteFirewallRule{ + RouteID: route.ID, + SourceRanges: sourceRanges, + Action: string(rule.Action), + Destination: route.Network.String(), + Protocol: string(rule.Protocol), + Domains: route.Domains, + IsDynamic: len(route.Domains) > 0, + } + + return []*RouteFirewallRule{baseRule} +} + +func (calc *NetworkMapCalculator) getNetworkResourcesRoutesToSync(ctx context.Context, peerID string) (bool, []*route.Route, map[string]struct{}) { + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + for _, resource := range calc.components.NetworkResources { + if !resource.Enabled { + continue + } + + var addSourcePeers bool + + networkRoutingPeers, exists := calc.components.RoutersMap[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, calc.getNetworkResourcesRoutes(resource, peerID, router)...) + } + } + + addedResourceRoute := false + for _, policy := range calc.components.ResourcePoliciesMap[resource.ID] { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = calc.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } + if addSourcePeers { + for _, pID := range calc.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + } else if calc.peerInSlice(peerID, peers) && calc.components.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, calc.getNetworkResourcesRoutes(resource, peerId, router)...) + } + addedResourceRoute = true + } + if addedResourceRoute { + break + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (calc *NetworkMapCalculator) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route { + resourceAppliedPolicies := calc.components.ResourcePoliciesMap[resource.ID] + + var routes []*route.Route + if len(resourceAppliedPolicies) > 0 { + peerInfo := calc.components.GetPeerInfo(peerID) + if peerInfo != nil { + routes = append(routes, calc.networkResourceToRoute(resource, peerInfo, router)) + } + } + + return routes +} + +func (calc *NetworkMapCalculator) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(resource.ID + ":" + peer.ID), + AccountID: resource.AccountID, + Peer: peer.Key, + PeerID: peer.ID, + Metric: router.Metric, + Masquerade: router.Masquerade, + Enabled: resource.Enabled, + KeepRoute: true, + NetID: route.NetID(resource.Name), + Description: resource.Description, + } + + if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet { + r.Network = resource.Prefix + + r.NetworkType = route.IPv4Network + if resource.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if resource.Type == resourceTypes.Domain { + domainList, err := domain.FromStringList([]string{resource.Domain}) + if err == nil { + r.Domains = domainList + r.NetworkType = route.DomainNetwork + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + } + + return r +} + +func (calc *NetworkMapCalculator) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if calc.components.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) { + dest = append(dest, peerID) + } + } + return dest +} + +func (calc *NetworkMapCalculator) peerInSlice(peerID string, peers []string) bool { + for _, p := range peers { + if p == peerID { + return true + } + } + return false +} + + +func (calc *NetworkMapCalculator) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + peerInfo := calc.components.GetPeerInfo(peerID) + if peerInfo == nil { + return routesFirewallRules + } + + for _, r := range routes { + if r.Peer != peerInfo.Key { + continue + } + + resourceID := string(r.GetResourceID()) + resourcePolicies := calc.components.ResourcePoliciesMap[resourceID] + distributionPeers := calc.getPoliciesSourcePeers(resourcePolicies) + + rules := calc.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + } + + return routesFirewallRules +} + +func (calc *NetworkMapCalculator) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := calc.components.GetGroupInfo(sourceGroup) + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } + } + } + + return sourcePeers +} + +func (calc *NetworkMapCalculator) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peerID string, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peerID) + delete(networkRoutesPeers, peerID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + peerInfo := calc.components.GetPeerInfo(p) + if peerInfo != nil { + peersToConnect = append(peersToConnect, peerInfo) + } + } + + return peersToConnect +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go index 1e8f917ec..fa2fe56b2 100644 --- a/management/server/types/networkmap_golden_test.go +++ b/management/server/types/networkmap_golden_test.go @@ -1290,19 +1290,35 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { compactCachedJSON, err := json.Marshal(compactCachedNm) require.NoError(b, err) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, validatedPeersMap, resourcePolicies, routers) + componentsJSON, err := json.Marshal(components) + require.NoError(b, err) + regularSize := len(regularJSON) compactSize := len(compactJSON) compactCachedSize := len(compactCachedJSON) - savingsPercent := 100 - int(float64(compactCachedSize)/float64(regularSize)*100) + componentsSize := len(componentsJSON) + + compactSavingsPercent := 100 - int(float64(compactCachedSize)/float64(regularSize)*100) + componentsSavingsPercent := 100 - int(float64(componentsSize)/float64(regularSize)*100) b.ReportMetric(float64(regularSize), "regular_bytes") b.ReportMetric(float64(compactCachedSize), "compact_cached_bytes") - b.ReportMetric(float64(savingsPercent), "savings_%") + b.ReportMetric(float64(componentsSize), "components_bytes") + b.ReportMetric(float64(compactSavingsPercent), "compact_savings_%") + b.ReportMetric(float64(componentsSavingsPercent), "components_savings_%") - b.Logf("Regular network map: %d bytes", regularSize) - b.Logf("Compact network map: %d bytes", compactSize) - b.Logf("Compact cached network map: %d bytes", compactCachedSize) - b.Logf("Data savings: %d%% (%d bytes saved)", savingsPercent, regularSize-compactCachedSize) + b.Logf("========== Network Map Size Comparison ==========") + b.Logf("Regular network map: %d bytes", regularSize) + b.Logf("Compact network map: %d bytes (-%d%%)", compactSize, 100-int(float64(compactSize)/float64(regularSize)*100)) + b.Logf("Compact cached network map: %d bytes (-%d%%)", compactCachedSize, compactSavingsPercent) + b.Logf("Components: %d bytes (-%d%%)", componentsSize, componentsSavingsPercent) + b.Logf("") + b.Logf("Bandwidth savings (Compact cached): %d bytes saved (%d%%)", regularSize-compactCachedSize, compactSavingsPercent) + b.Logf("Bandwidth savings (Components): %d bytes saved (%d%%)", regularSize-componentsSize, componentsSavingsPercent) + b.Logf("=================================================") b.Run("Regular", func(b *testing.B) { b.ResetTimer() @@ -1324,4 +1340,52 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { _ = builder.GetPeerNetworkMapCompactCached(ctx, testingPeerID, customZone, validatedPeersMap, nil) } }) + b.Run("Legacy", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + }) + b.Run("LegacyCompacted", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapCompacted(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } + }) + + b.Run("ComponentsNetworkMap", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + testingPeerID, + customZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + _ = types.CalculateNetworkMapFromComponents(ctx, components) + } + }) + + b.Run("ComponentsCreation", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapComponents( + ctx, + testingPeerID, + customZone, + validatedPeersMap, + resourcePolicies, + routers, + ) + } + }) + + b.Run("CalculationFromComponents", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = types.CalculateNetworkMapFromComponents(ctx, components) + } + }) }