diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index d46737c26..4def59a53 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -63,6 +63,8 @@ type Controller struct { expNewNetworkMap bool expNewNetworkMapAIDs map[string]struct{} + + compactedNetworkMap bool } type bufferUpdate struct { @@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App newNetworkMapBuilder = false } + compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err) + compactedNetworkMap = false + } + ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",") expIDs := make(map[string]struct{}, len(ids)) for _, id := range ids { @@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App holder: types.NewHolder(), expNewNetworkMap: newNetworkMapBuilder, expNewNetworkMapAIDs: expIDs, + + compactedNetworkMap: compactedNetworkMap, } } @@ -233,6 +243,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) } else { remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + if c.compactedNetworkMap { + account.ShadowCompareNetworkMap(ctx, p.ID, remotePeerNetworkMap, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs, c.accountManagerMetrics) + } } 
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -354,6 +367,9 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) } else { remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + if c.compactedNetworkMap { + account.ShadowCompareNetworkMap(ctx, peerId, remotePeerNetworkMap, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs, c.accountManagerMetrics) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -469,7 +485,11 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr } else { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers()) + groupIDToUserIDs := account.GetActiveGroupUsers() + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + if c.compactedNetworkMap { + account.ShadowCompareNetworkMap(ctx, peer.ID, networkMap, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs, c.accountManagerMetrics) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -843,6 +863,9 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, 
account.GetActiveGroupUsers()) + if c.compactedNetworkMap { + account.ShadowCompareNetworkMap(ctx, peer.ID, networkMap, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, account.GetActiveGroupUsers(), c.accountManagerMetrics) + } } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index 3b1e078eb..d32ec0529 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -14,6 +14,10 @@ type AccountManagerMetrics struct { getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram peerMetaUpdateCount metric.Int64Counter + + shadowLegacySizeBytes metric.Int64Histogram + shadowComponentsSizeBytes metric.Int64Histogram + shadowSavingsPercent metric.Int64Histogram } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -55,12 +59,46 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + shadowLegacySizeBytes, err := meter.Int64Histogram("management.account.shadow.legacy.size.bytes", + metric.WithUnit("bytes"), + metric.WithExplicitBucketBoundaries( + 1024, 5120, 10240, 51200, 102400, 512000, 1048576, 5242880, 10485760, + ), + metric.WithDescription("Size of legacy network map in bytes")) + if err != nil { + return nil, err + } + + shadowComponentsSizeBytes, err := meter.Int64Histogram("management.account.shadow.components.size.bytes", + metric.WithUnit("bytes"), + metric.WithExplicitBucketBoundaries( + 1024, 5120, 10240, 51200, 102400, 512000, 1048576, 5242880, 10485760, + ), + metric.WithDescription("Size of components-based network map in bytes")) + if err != nil { + return nil, err + } + + shadowSavingsPercent, err := meter.Int64Histogram("management.account.shadow.savings.percent", + metric.WithUnit("percent"), + metric.WithExplicitBucketBoundaries( + 
0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, + ), + metric.WithDescription("Percentage of bandwidth savings with components-based network map")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, networkMapObjectCount: networkMapObjectCount, peerMetaUpdateCount: peerMetaUpdateCount, + + shadowLegacySizeBytes: shadowLegacySizeBytes, + shadowComponentsSizeBytes: shadowComponentsSizeBytes, + shadowSavingsPercent: shadowSavingsPercent, }, nil } @@ -84,3 +122,18 @@ func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) } + +// CountShadowLegacySize records the size of legacy network map in bytes +func (metrics *AccountManagerMetrics) CountShadowLegacySize(bytes int64) { + metrics.shadowLegacySizeBytes.Record(metrics.ctx, bytes) +} + +// CountShadowComponentsSize records the size of components-based network map in bytes +func (metrics *AccountManagerMetrics) CountShadowComponentsSize(bytes int64) { + metrics.shadowComponentsSizeBytes.Record(metrics.ctx, bytes) +} + +// CountShadowSavingsPercent records the percentage of bandwidth savings +func (metrics *AccountManagerMetrics) CountShadowSavingsPercent(percent int64) { + metrics.shadowSavingsPercent.Record(metrics.ctx, percent) +} diff --git a/management/server/types/account_components.go b/management/server/types/account_components.go new file mode 100644 index 000000000..a2e5877f1 --- /dev/null +++ b/management/server/types/account_components.go @@ -0,0 +1,444 @@ +package types + +import ( + "context" + "slices" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes 
"github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" +) + +func (a *Account) GetPeerNetworkMapComponents( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + groupIDToUserIDs map[string][]string, +) *NetworkMapComponents { + + peer := a.Peers[peerID] + if peer == nil { + return nil + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return nil + } + + components := &NetworkMapComponents{ + PeerID: peerID, + Serial: a.Network.Serial, + Network: a.Network.Copy(), + NameServerGroups: make([]*nbdns.NameServerGroup, 0), + CustomZoneDomain: peersCustomZone.Domain, + ResourcePoliciesMap: make(map[string][]*Policy), + RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter), + NetworkResources: make([]*resourceTypes.NetworkResource, 0), + GroupIDToUserIDs: groupIDToUserIDs, + AllowedUserIDs: a.getAllowedUserIDs(), + PostureFailedPeers: make(map[string]map[string]struct{}, len(a.Policies)), + } + + components.AccountSettings = &AccountSettingsInfo{ + PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled, + PeerLoginExpiration: a.Settings.PeerLoginExpiration, + PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: a.Settings.PeerInactivityExpiration, + } + + components.DNSSettings = &a.DNSSettings + + relevantPeers, relevantGroups, relevantPolicies, relevantRoutes := a.getPeersGroupsPoliciesRoutes(ctx, peerID, validatedPeersMap, &components.PostureFailedPeers) + + components.Peers = relevantPeers + components.Groups = relevantGroups + components.Policies = relevantPolicies + components.Routes = relevantRoutes + components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, 
relevantPeers) + + peerGroups := a.GetPeerGroups(peerID) + components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups) + + for _, nsGroup := range a.NameServerGroups { + if nsGroup.Enabled { + for _, gID := range nsGroup.Groups { + if _, found := relevantGroups[gID]; found { + components.NameServerGroups = append(components.NameServerGroups, nsGroup) + break + } + } + } + } + + for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + + policies, exists := resourcePolicies[resource.ID] + if !exists { + continue + } + + addSourcePeers := false + + networkRoutingPeers, routerExists := routers[resource.NetworkID] + if routerExists { + if _, ok := networkRoutingPeers[peerID]; ok { + addSourcePeers = true + } + } + + for _, policy := range policies { + if addSourcePeers { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } + for _, pID := range a.getPostureValidPeersSaveFailed(peers, policy.SourcePostureChecks, &components.PostureFailedPeers) { + if _, exists := components.Peers[pID]; !exists { + components.Peers[pID] = a.GetPeer(pID) + } + } + } else { + peerInSources := false + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peerInSources = policy.Rules[0].SourceResource.ID == peerID + } else { + for _, groupID := range policy.SourceGroups() { + if group := a.GetGroup(groupID); group != nil && slices.Contains(group.Peers, peerID) { + peerInSources = true + break + } + } + } + if !peerInSources { + continue + } + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, policy.SourcePostureChecks, peerID) + if !isValid && len(pname) > 0 { + if _, ok := (*components).PostureFailedPeers[pname]; !ok { + (*components).PostureFailedPeers[pname] = 
make(map[string]struct{}) + } + (*components).PostureFailedPeers[pname][peer.ID] = struct{}{} + continue + } + addSourcePeers = true + } + + for _, rule := range policy.Rules { + for _, srcGroupID := range rule.Sources { + if g := a.Groups[srcGroupID]; g != nil { + if _, exists := components.Groups[srcGroupID]; !exists { + components.Groups[srcGroupID] = g + } + } + } + for _, dstGroupID := range rule.Destinations { + if g := a.Groups[dstGroupID]; g != nil { + if _, exists := components.Groups[dstGroupID]; !exists { + components.Groups[dstGroupID] = g + } + } + } + } + components.ResourcePoliciesMap[resource.ID] = policies + } + + components.RoutersMap[resource.NetworkID] = networkRoutingPeers + for peerIDKey := range networkRoutingPeers { + if _, exists := components.Peers[peerIDKey]; !exists { + if p := a.Peers[peerIDKey]; p != nil { + components.Peers[peerIDKey] = p + } + } + } + + if addSourcePeers { + components.NetworkResources = append(components.NetworkResources, resource) + } + } + + filterGroupPeers(&components.Groups, components.Peers) + + return components +} + +func (a *Account) getPeersGroupsPoliciesRoutes( + ctx context.Context, + peerID string, + validatedPeersMap map[string]struct{}, + postureFailedPeers *map[string]map[string]struct{}, +) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route) { + relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4) + relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4) + relevantPolicies := make([]*Policy, 0, len(a.Policies)) + relevantRoutes := make([]*route.Route, 0, len(a.Routes)) + + relevantPeerIDs[peerID] = a.GetPeer(peerID) + + for groupID, group := range a.Groups { + if slices.Contains(group.Peers, peerID) { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + } + + routeAccessControlGroups := make(map[string]struct{}) + for _, r := range a.Routes { + for _, groupID := range r.Groups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + for _, groupID := range 
r.PeerGroups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + } + if r.Enabled { + for _, groupID := range r.AccessControlGroups { + relevantGroupIDs[groupID] = a.GetGroup(groupID) + routeAccessControlGroups[groupID] = struct{}{} + } + } + relevantRoutes = append(relevantRoutes, r) + } + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + policyRelevant := false + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + if len(routeAccessControlGroups) > 0 { + for _, destGroupID := range rule.Destinations { + if _, needed := routeAccessControlGroups[destGroupID]; needed { + policyRelevant = true + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID) + } + break + } + } + } + + var sourcePeers, destinationPeers []string + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers = []string{rule.SourceResource.ID} + if rule.SourceResource.ID == peerID { + peerInSources = true + } + } else { + sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap, postureFailedPeers) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers = []string{rule.DestinationResource.ID} + if rule.DestinationResource.ID == peerID { + peerInDestinations = true + } + } else { + destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap, postureFailedPeers) + } + + if peerInSources { + policyRelevant = true + for _, pid := range destinationPeers { + relevantPeerIDs[pid] = a.GetPeer(pid) + } + for _, dstGroupID := range rule.Destinations { + relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID) + } + } + + if peerInDestinations 
{ + policyRelevant = true + for _, pid := range sourcePeers { + relevantPeerIDs[pid] = a.GetPeer(pid) + } + for _, srcGroupID := range rule.Sources { + relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID) + } + } + } + if policyRelevant { + relevantPolicies = append(relevantPolicies, policy) + } + } + + return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes +} + +func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, + validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) { + peerInGroups := false + filteredPeerIDs := make([]string, 0, len(a.Peers)) + seenPeerIds := make(map[string]struct{}, len(groups)) + + for _, gid := range groups { + group := a.GetGroup(gid) + if group == nil { + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + filteredPeerIDs = filteredPeerIDs[:0] + seenPeerIds = make(map[string]struct{}, len(group.Peers)) + peerInGroups = false + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok || peer == nil { + continue + } + + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid && len(pname) > 0 { + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peer.ID] = struct{}{} + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + return filteredPeerIDs, peerInGroups + } + + for _, pid := range group.Peers { + if _, seen := seenPeerIds[pid]; seen { + continue + } + seenPeerIds[pid] = struct{}{} + peer, ok := a.Peers[pid] + if !ok || peer == nil { + continue + } + + isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid && len(pname) > 0 { + 
if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peer.ID] = struct{}{} + continue + } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeerIDs = append(filteredPeerIDs, peer.ID) + } + } + + return filteredPeerIDs, peerInGroups +} + +func (a *Account) validatePostureChecksOnPeerGetFailed(ctx context.Context, sourcePostureChecksID []string, peerID string) (bool, string) { + peer, ok := a.Peers[peerID] + if !ok || peer == nil { // reject a present-but-nil entry as well: peer is dereferenced by check.Check below + return false, "" + } + + for _, postureChecksID := range sourcePostureChecksID { + postureChecks := a.GetPostureChecks(postureChecksID) + if postureChecks == nil { + continue + } + + for _, check := range postureChecks.GetChecks() { + isValid, _ := check.Check(ctx, *peer) + if !isValid { + return false, postureChecksID + } + } + } + return true, "" +} + +func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChecksIDs []string, postureFailedPeers *map[string]map[string]struct{}) []string { + var dest []string + for _, peerID := range inputPeers { + valid, pname := a.validatePostureChecksOnPeerGetFailed(context.Background(), postureChecksIDs, peerID) + if valid { + dest = append(dest, peerID) + continue + } + if _, ok := (*postureFailedPeers)[pname]; !ok { + (*postureFailedPeers)[pname] = make(map[string]struct{}) + } + (*postureFailedPeers)[pname][peerID] = struct{}{} + } + return dest +} + +func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) { + for groupID, groupInfo := range *groups { + filteredPeers := make([]string, 0, len(groupInfo.Peers)) + for _, pid := range groupInfo.Peers { + if _, exists := peers[pid]; exists { + filteredPeers = append(filteredPeers, pid) + } + } + + if len(filteredPeers) == 0 { + delete(*groups, groupID) + } else if len(filteredPeers) != len(groupInfo.Peers) { + ng := 
groupInfo.Copy() + ng.Peers = filteredPeers + (*groups)[groupID] = ng + } + } +} + +func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord { + if len(records) == 0 || len(peers) == 0 { + return nil + } + + peerIPs := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + if peer != nil { + peerIPs[peer.IP.String()] = struct{}{} + } + } + + filteredRecords := make([]nbdns.SimpleRecord, 0, len(records)) + for _, record := range records { + if _, exists := peerIPs[record.RData]; exists { + filteredRecords = append(filteredRecords, record) + } + } + + return filteredRecords +} diff --git a/management/server/types/networkmap_comparison_test.go b/management/server/types/networkmap_comparison_test.go new file mode 100644 index 000000000..446494db8 --- /dev/null +++ b/management/server/types/networkmap_comparison_test.go @@ -0,0 +1,841 @@ +package types + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/route" +) + +func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + 
resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + if components == nil { + t.Fatal("GetPeerNetworkMapComponents returned nil") + } + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + + if newNetworkMap == nil { + t.Fatal("CalculateNetworkMapFromComponents returned nil") + } + + compareNetworkMaps(t, legacyNetworkMap, newNetworkMap) +} + +func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) { + account := createTestAccount() + ctx := context.Background() + + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid == offlinePeerID { + continue + } + validatedPeersMap[pid] = struct{}{} + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + legacyNetworkMap := account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") + + 
normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + componentsJSON, err := json.MarshalIndent(components, "", " ") + require.NoError(t, err, "error marshaling components to JSON") + + legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") + + newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") + + goldenDir := filepath.Join("testdata", "comparison") + err = os.MkdirAll(goldenDir, 0755) + require.NoError(t, err) + + legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json") + err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644) + require.NoError(t, err, "error writing legacy golden file") + + newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json") + err = os.WriteFile(newGoldenPath, newJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + componentsPath := filepath.Join(goldenDir, "components.json") + err = os.WriteFile(componentsPath, componentsJSON, 0644) + require.NoError(t, err, "error writing components golden file") + + require.JSONEq(t, string(legacyJSON), string(newJSON), + "NetworkMaps from legacy and components approaches do not match.\n"+ + "Legacy JSON saved to: %s\n"+ + "Components JSON saved to: %s", + legacyGoldenPath, newGoldenPath) + + t.Logf("✅ NetworkMaps are identical") + t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath) + t.Logf(" Components NetworkMap: %s", newGoldenPath) +} + +func normalizeAndSortNetworkMap(nm *NetworkMap) { + if nm == nil { + return + } + + sort.Slice(nm.Peers, func(i, j int) bool { + return nm.Peers[i].ID < nm.Peers[j].ID + }) + + sort.Slice(nm.OfflinePeers, func(i, j int) bool { + return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID + }) + + sort.Slice(nm.Routes, func(i, j int) bool { + return string(nm.Routes[i].ID) < string(nm.Routes[j].ID) + }) + + 
sort.Slice(nm.FirewallRules, func(i, j int) bool { + if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP { + return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP + } + if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction { + return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction + } + if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol { + return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol + } + if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port { + return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port + } + return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID + }) + + for i := range nm.RoutesFirewallRules { + sort.Strings(nm.RoutesFirewallRules[i].SourceRanges) + } + + sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool { + if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination { + return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination + } + + minLen := len(nm.RoutesFirewallRules[i].SourceRanges) + if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen { + minLen = len(nm.RoutesFirewallRules[j].SourceRanges) + } + for k := 0; k < minLen; k++ { + if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] { + return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k] + } + } + if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) { + return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges) + } + + if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) { + return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID) + } + + if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID { + return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID + } + + if 
nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port { + return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port + } + + return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol + }) + + if nm.DNSConfig.CustomZones != nil { + for i := range nm.DNSConfig.CustomZones { + sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool { + return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name + }) + } + } + + if len(nm.DNSConfig.NameServerGroups) != 0 { + sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool { + return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name + }) + } +} + +func compareNetworkMaps(t *testing.T, legacy, new *NetworkMap) { + t.Helper() + + if legacy.Network.Serial != new.Network.Serial { + t.Errorf("Network Serial mismatch: legacy=%d, new=%d", legacy.Network.Serial, new.Network.Serial) + } + + if len(legacy.Peers) != len(new.Peers) { + t.Errorf("Peers count mismatch: legacy=%d, new=%d", len(legacy.Peers), len(new.Peers)) + } + + legacyPeerIDs := make(map[string]bool) + for _, p := range legacy.Peers { + legacyPeerIDs[p.ID] = true + } + + for _, p := range new.Peers { + if !legacyPeerIDs[p.ID] { + t.Errorf("New NetworkMap contains peer %s not in legacy", p.ID) + } + } + + if len(legacy.OfflinePeers) != len(new.OfflinePeers) { + t.Errorf("OfflinePeers count mismatch: legacy=%d, new=%d", len(legacy.OfflinePeers), len(new.OfflinePeers)) + } + + if len(legacy.FirewallRules) != len(new.FirewallRules) { + t.Logf("FirewallRules count mismatch: legacy=%d, new=%d", len(legacy.FirewallRules), len(new.FirewallRules)) + } + + if len(legacy.Routes) != len(new.Routes) { + t.Logf("Routes count mismatch: legacy=%d, new=%d", len(legacy.Routes), len(new.Routes)) + } + + if len(legacy.RoutesFirewallRules) != len(new.RoutesFirewallRules) { + t.Logf("RoutesFirewallRules count mismatch: legacy=%d, new=%d", len(legacy.RoutesFirewallRules), 
len(new.RoutesFirewallRules)) + } + + if legacy.DNSConfig.ServiceEnable != new.DNSConfig.ServiceEnable { + t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, new=%v", legacy.DNSConfig.ServiceEnable, new.DNSConfig.ServiceEnable) + } +} + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" + expiredPeerID = "peer-98" + offlinePeerID = "peer-99" + routingPeerID = "peer-95" + testAccountID = "account-comparison-test" +) + +func createTestAccount() *Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + } + + groups := map[string]*Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: 
allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop, + Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: 
[]string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &Network{ + Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*nbdns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &Settings{PeerLoginExpirationEnabled: true, 
PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} + +func BenchmarkLegacyNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + nil, + groupIDToUserIDs, + ) + } +} + +func BenchmarkComponentsNetworkMap(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} + +func BenchmarkComponentsCreation(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + 
if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + } +} + +func BenchmarkCalculationFromComponents(b *testing.B) { + account := createTestAccount() + ctx := context.Background() + peerID := testingPeerID + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + pid := fmt.Sprintf("peer-%d", i) + if pid != offlinePeerID { + validatedPeersMap[pid] = struct{}{} + } + } + + peersCustomZone := nbdns.CustomZone{} + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + components := account.GetPeerNetworkMapComponents( + ctx, + peerID, + peersCustomZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + groupIDToUserIDs, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculateNetworkMapFromComponents(ctx, components) + } +} + +func TestGetPeerNetworkMap_ProdAccount_CompareImplementations(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + + testAccount := loadProdAccountFromJSON(t) + + testingPeerID := "cq3526bl0ubs73bbtpbg" + require.Contains(t, testAccount.Peers, testingPeerID, "Testing peer should exist in account") + + validatedPeersMap := make(map[string]struct{}) + for peerID := range testAccount.Peers { + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := testAccount.GetResourcePoliciesMap() + routers := testAccount.GetResourceRoutersMap() + groupIDToUserIDs := testAccount.GetActiveGroupUsers() + + legacyNetworkMap 
:= testAccount.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, groupIDToUserIDs) + require.NotNil(t, legacyNetworkMap, "GetPeerNetworkMap returned nil") + + components := testAccount.GetPeerNetworkMapComponents(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs) + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil") + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil") + + normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + componentsJSON, err := json.MarshalIndent(components, "", " ") + require.NoError(t, err, "error marshaling components to JSON") + + legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ") + require.NoError(t, err, "error marshaling legacy network map to JSON") + + newJSON, err := json.MarshalIndent(newNetworkMap, "", " ") + require.NoError(t, err, "error marshaling new network map to JSON") + + outputDir := filepath.Join("testdata", fmt.Sprintf("compare_peer_%s", testingPeerID)) + err = os.MkdirAll(outputDir, 0755) + require.NoError(t, err) + + legacyFilePath := filepath.Join(outputDir, "legacy_networkmap.json") + err = os.WriteFile(legacyFilePath, legacyJSON, 0644) + require.NoError(t, err) + + componentsPath := filepath.Join(outputDir, "components.json") + err = os.WriteFile(componentsPath, componentsJSON, 0644) + require.NoError(t, err) + + newFilePath := filepath.Join(outputDir, "components_networkmap.json") + err = os.WriteFile(newFilePath, newJSON, 0644) + require.NoError(t, err) + + t.Logf("Files saved to:\n Legacy NetworkMap: %s\n Components: %s\n Components NetworkMap: %s", + legacyFilePath, componentsPath, newFilePath) + + require.JSONEq(t, string(legacyJSON), string(newJSON), + "NetworkMaps from legacy and components approaches do not match for 
peer %s.\n"+ + "Legacy JSON saved to: %s\n"+ + "Components JSON saved to: %s\n"+ + "Components NetworkMap saved to: %s", + testingPeerID, legacyFilePath, componentsPath, newFilePath) + + t.Logf("✅ NetworkMaps are identical for peer %s", testingPeerID) +} + +func TestGetPeerNetworkMap_ProdAccount_AllPeers(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + + testAccount := loadProdAccountFromJSON(t) + + validatedPeersMap := make(map[string]struct{}) + for peerID := range testAccount.Peers { + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := testAccount.GetResourcePoliciesMap() + routers := testAccount.GetResourceRoutersMap() + groupIDToUserIDs := testAccount.GetActiveGroupUsers() + + var failedPeers []string + + for peerID := range testAccount.Peers { + legacyNetworkMap := testAccount.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil, groupIDToUserIDs) + require.NotNil(t, legacyNetworkMap, "GetPeerNetworkMap returned nil for peer %s", peerID) + + components := testAccount.GetPeerNetworkMapComponents(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, groupIDToUserIDs) + require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil for peer %s", peerID) + + newNetworkMap := CalculateNetworkMapFromComponents(ctx, components) + require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil for peer %s", peerID) + + normalizeAndSortNetworkMap(legacyNetworkMap) + normalizeAndSortNetworkMap(newNetworkMap) + + legacyJSON, err := json.Marshal(legacyNetworkMap) + require.NoError(t, err, "error marshaling legacy network map for peer %s", peerID) + + newJSON, err := json.Marshal(newNetworkMap) + require.NoError(t, err, "error marshaling new network map for peer %s", peerID) + + if string(legacyJSON) != string(newJSON) { + failedPeers = append(failedPeers, peerID) + + 
outputDir := filepath.Join("testdata", fmt.Sprintf("failed_peer_%s", peerID)) + require.NoError(t, os.MkdirAll(outputDir, 0755)) + + legacyIndented, _ := json.MarshalIndent(legacyNetworkMap, "", " ") + newIndented, _ := json.MarshalIndent(newNetworkMap, "", " ") + componentsIndented, _ := json.MarshalIndent(components, "", " ") + + _ = os.WriteFile(filepath.Join(outputDir, "legacy_networkmap.json"), legacyIndented, 0644) + _ = os.WriteFile(filepath.Join(outputDir, "components_networkmap.json"), newIndented, 0644) + _ = os.WriteFile(filepath.Join(outputDir, "components.json"), componentsIndented, 0644) + + t.Errorf("NetworkMap mismatch for peer %s. Files saved to %s", peerID, outputDir) + } + } + + require.Empty(t, failedPeers, "NetworkMap comparison failed for %d peers: %v", len(failedPeers), failedPeers) + + t.Logf("✅ NetworkMaps are identical for all %d peers", len(testAccount.Peers)) +} + +func loadProdAccountFromJSON(t testing.TB) *Account { + t.Helper() + + testDataPath := filepath.Join("testdata", "account_cnlf3j3l0ubs738o5d4g.json") + data, err := os.ReadFile(testDataPath) + require.NoError(t, err, "Failed to read prod account JSON file") + + var account Account + err = json.Unmarshal(data, &account) + require.NoError(t, err, "Failed to unmarshal prod account") + + if account.Groups == nil { + account.Groups = make(map[string]*Group) + } + if account.Peers == nil { + account.Peers = make(map[string]*nbpeer.Peer) + } + if account.Policies == nil { + account.Policies = []*Policy{} + } + + return &account +} + +func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) { + account := loadProdAccountFromJSON(b) + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}, len(account.Peers)) + for _, peer := range account.Peers { + validatedPeersMap[peer.ID] = struct{}{} + } + dnsDomain := account.Settings.DNSDomain + customZone := account.GetPeersCustomZone(ctx, dnsDomain) + + builder := NewNetworkMapBuilder(account, validatedPeersMap) + + 
testingPeerID := "d3knp53l0ubs738a3n6g" + + regularNm := builder.GetPeerNetworkMap(ctx, testingPeerID, customZone, nil, validatedPeersMap, nil) + + regularJSON, err := json.Marshal(regularNm) + require.NoError(b, err) + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + agUsers := account.GetActiveGroupUsers() + components := account.GetPeerNetworkMapComponents(ctx, testingPeerID, customZone, nil, validatedPeersMap, resourcePolicies, routers, agUsers) + componentsJSON, err := json.Marshal(components) + require.NoError(b, err) + + regularSize := len(regularJSON) + componentsSize := len(componentsJSON) + + componentsSavingsPercent := 100 - int(float64(componentsSize)/float64(regularSize)*100) + + b.ReportMetric(float64(regularSize), "regular_bytes") + b.ReportMetric(float64(componentsSize), "components_bytes") + b.ReportMetric(float64(componentsSavingsPercent), "components_savings_%") + + b.Logf("========== Network Map Size Comparison ==========") + b.Logf("Regular network map: %d bytes", regularSize) + b.Logf("Components: %d bytes (-%d%%)", componentsSize, componentsSavingsPercent) + b.Logf("") + b.Logf("Bandwidth savings (Components): %d bytes saved (%d%%)", regularSize-componentsSize, componentsSavingsPercent) + b.Logf("=================================================") + + b.Run("Legacy", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, nil, validatedPeersMap, resourcePolicies, routers, nil, agUsers) + } + }) + + b.Run("ComponentsNetworkMap", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + components := account.GetPeerNetworkMapComponents( + ctx, + testingPeerID, + customZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + agUsers, + ) + _ = CalculateNetworkMapFromComponents(ctx, components) + } + }) + + b.Run("ComponentsCreation", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { 
+ _ = account.GetPeerNetworkMapComponents( + ctx, + testingPeerID, + customZone, + nil, + validatedPeersMap, + resourcePolicies, + routers, + agUsers, + ) + } + }) + + b.Run("CalculationFromComponents", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculateNetworkMapFromComponents(ctx, components) + } + }) +} diff --git a/management/server/types/networkmap_components.go b/management/server/types/networkmap_components.go new file mode 100644 index 000000000..a17464749 --- /dev/null +++ b/management/server/types/networkmap_components.go @@ -0,0 +1,966 @@ +package types + +import ( + "context" + "net" + "net/netip" + "slices" + "strings" + "time" + + "github.com/netbirdio/netbird/client/ssh/auth" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type NetworkMapComponents struct { + PeerID string + Serial uint64 + + Network *Network + AccountSettings *AccountSettingsInfo + DNSSettings *DNSSettings + CustomZoneDomain string + + Peers map[string]*nbpeer.Peer + Groups map[string]*Group + Policies []*Policy + Routes []*route.Route + NameServerGroups []*nbdns.NameServerGroup + AllDNSRecords []nbdns.SimpleRecord + AccountZones []nbdns.CustomZone + ResourcePoliciesMap map[string][]*Policy + RoutersMap map[string]map[string]*routerTypes.NetworkRouter + NetworkResources []*resourceTypes.NetworkResource + + GroupIDToUserIDs map[string][]string + AllowedUserIDs map[string]struct{} + PostureFailedPeers map[string]map[string]struct{} +} + +type AccountSettingsInfo struct { + PeerLoginExpirationEnabled bool + PeerLoginExpiration time.Duration + PeerInactivityExpirationEnabled bool + PeerInactivityExpiration time.Duration +} + 
+func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer { + return c.Peers[peerID] +} + +func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group { + return c.Groups[groupID] +} + +func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool { + group := c.GetGroupInfo(groupID) + if group == nil { + return false + } + + return slices.Contains(group.Peers, peerID) +} + +func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} { + groups := make(map[string]struct{}) + for groupID, group := range c.Groups { + if slices.Contains(group.Peers, peerID) { + groups[groupID] = struct{}{} + } + } + return groups +} + +func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool { + _, exists := c.Peers[peerID] + if !exists { + return false + } + if len(postureCheckIDs) == 0 { + return true + } + for _, checkID := range postureCheckIDs { + if failedPeers, exists := c.PostureFailedPeers[checkID]; exists { + if _, failed := failedPeers[peerID]; failed { + return false + } + } + } + return true +} + +type NetworkMapCalculator struct { + components *NetworkMapComponents +} + +func NewNetworkMapCalculator(components *NetworkMapComponents) *NetworkMapCalculator { + return &NetworkMapCalculator{ + components: components, + } +} + +func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap { + calculator := NewNetworkMapCalculator(components) + return calculator.Calculate(ctx) +} + +func (calc *NetworkMapCalculator) Calculate(ctx context.Context) *NetworkMap { + targetPeerID := calc.components.PeerID + + peerGroups := calc.components.GetPeerGroups(targetPeerID) + + aclPeers, firewallRules, authorizedUsers, sshEnabled := calc.getPeerConnectionResources(ctx, targetPeerID) + + peersToConnect, expiredPeers := calc.filterPeersByLoginExpiration(aclPeers) + + routesUpdate := calc.getRoutesToSync(ctx, targetPeerID, peersToConnect, peerGroups) 
+ routesFirewallRules := calc.getPeerRoutesFirewallRules(ctx, targetPeerID) + + isRouter, networkResourcesRoutes, sourcePeers := calc.getNetworkResourcesRoutesToSync(ctx, targetPeerID) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = calc.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes) + } + + peersToConnectIncludingRouters := calc.addNetworksRoutingPeers( + networkResourcesRoutes, + targetPeerID, + peersToConnect, + expiredPeers, + isRouter, + sourcePeers, + ) + + dnsManagementStatus := calc.getPeerDNSManagementStatus(targetPeerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var customZones []nbdns.CustomZone + + if calc.components.CustomZoneDomain != "" && len(calc.components.AllDNSRecords) > 0 { + customZones = append(customZones, nbdns.CustomZone{ + Domain: calc.components.CustomZoneDomain, + Records: calc.components.AllDNSRecords, + }) + } + + customZones = append(customZones, calc.components.AccountZones...) 
+ + dnsUpdate.CustomZones = customZones + dnsUpdate.NameServerGroups = calc.getPeerNSGroups(targetPeerID) + } + + return &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: calc.components.Network.Copy(), + Routes: append(networkResourcesRoutes, routesUpdate...), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...), + AuthorizedUsers: authorizedUsers, + EnableSSH: sshEnabled, + } +} + +func (calc *NetworkMapCalculator) getPeerConnectionResources(ctx context.Context, targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) { + targetPeer := calc.components.GetPeerInfo(targetPeerID) + if targetPeer == nil { + return nil, nil, nil, false + } + + generateResources, getAccumulatedResources := calc.connResourcesGenerator(ctx, targetPeer) + authorizedUsers := make(map[string]map[string]struct{}) + sshEnabled := false + + for _, policy := range calc.components.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers, peerInSources = calc.getPeerFromResource(rule.SourceResource, targetPeerID) + } else { + sourcePeers, peerInSources = calc.getAllPeersFromGroups(ctx, rule.Sources, targetPeerID, policy.SourcePostureChecks) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + destinationPeers, peerInDestinations = calc.getPeerFromResource(rule.DestinationResource, targetPeerID) + } else { + destinationPeers, peerInDestinations = calc.getAllPeersFromGroups(ctx, rule.Destinations, targetPeerID, nil) + } + + if rule.Bidirectional { + if peerInSources { + generateResources(rule, destinationPeers, 
FirewallRuleDirectionIN) + } + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionOUT) + } + } + + if peerInSources { + generateResources(rule, destinationPeers, FirewallRuleDirectionOUT) + } + + if peerInDestinations { + generateResources(rule, sourcePeers, FirewallRuleDirectionIN) + } + + if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH { + sshEnabled = true + switch { + case len(rule.AuthorizedGroups) > 0: + for groupID, localUsers := range rule.AuthorizedGroups { + userIDs, ok := calc.components.GroupIDToUserIDs[groupID] + if !ok { + continue + } + + if len(localUsers) == 0 { + localUsers = []string{auth.Wildcard} + } + + for _, localUser := range localUsers { + if authorizedUsers[localUser] == nil { + authorizedUsers[localUser] = make(map[string]struct{}) + } + for _, userID := range userIDs { + authorizedUsers[localUser][userID] = struct{}{} + } + } + } + case rule.AuthorizedUser != "": + if authorizedUsers[auth.Wildcard] == nil { + authorizedUsers[auth.Wildcard] = make(map[string]struct{}) + } + authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{} + default: + authorizedUsers[auth.Wildcard] = calc.getAllowedUserIDs() + } + } else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled { + sshEnabled = true + authorizedUsers[auth.Wildcard] = calc.getAllowedUserIDs() + } + } + } + + peers, fwRules := getAccumulatedResources() + return peers, fwRules, authorizedUsers, sshEnabled +} + +func (calc *NetworkMapCalculator) getAllowedUserIDs() map[string]struct{} { + if calc.components.AllowedUserIDs != nil { + result := make(map[string]struct{}, len(calc.components.AllowedUserIDs)) + for k, v := range calc.components.AllowedUserIDs { + result[k] = v + } + return result + } + return make(map[string]struct{}) +} + +func (calc *NetworkMapCalculator) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() 
([]*nbpeer.Peer, []*FirewallRule)) { + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + rules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { + for _, peer := range groupPeers { + if peer == nil { + continue + } + + if _, ok := peersExists[peer.ID]; !ok { + peers = append(peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: net.IP(peer.IP).String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + ruleID := rule.ID + fr.PeerIP + string(rune(direction)) + + fr.Protocol + fr.Action + for _, port := range rule.Ports { + ruleID += port + } + if _, ok := rulesExists[ruleID]; ok { + continue + } + rulesExists[ruleID] = struct{}{} + + if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 { + rules = append(rules, &fr) + continue + } + + rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{ + ID: rule.ID, + Ports: rule.Ports, + PortRanges: rule.PortRanges, + Protocol: rule.Protocol, + Action: rule.Action, + }, targetPeer)...) 
+ } + }, func() ([]*nbpeer.Peer, []*FirewallRule) { + return peers, rules + } +} + +func (calc *NetworkMapCalculator) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) { + peerInGroups := false + uniquePeerIDs := calc.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + + for _, p := range uniquePeerIDs { + peerInfo := calc.components.GetPeerInfo(p) + if peerInfo == nil { + continue + } + + if _, ok := calc.components.Peers[p]; !ok { + continue + } + + if !calc.components.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) { + continue + } + + if p == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peerInfo) + } + + return filteredPeers, peerInGroups +} + +func (calc *NetworkMapCalculator) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { + peerIDs := make(map[string]struct{}, len(groups)) + for _, groupID := range groups { + group := calc.components.GetGroupInfo(groupID) + if group == nil { + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } + } + + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) + } + + return ids +} + +func (calc *NetworkMapCalculator) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) { + if resource.ID == peerID { + return []*nbpeer.Peer{}, true + } + + peerInfo := calc.components.GetPeerInfo(resource.ID) + if peerInfo == nil { + return []*nbpeer.Peer{}, false + } + + return []*nbpeer.Peer{peerInfo}, false +} + +func (calc *NetworkMapCalculator) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) { + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, p := range aclPeers { + expired, _ := 
p.LoginExpired(calc.components.AccountSettings.PeerLoginExpiration) + if calc.components.AccountSettings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + return peersToConnect, expiredPeers +} + +func (calc *NetworkMapCalculator) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := calc.components.GetPeerGroups(peerID) + enabled := true + for _, groupID := range calc.components.DNSSettings.DisabledManagementGroups { + if _, found := peerGroups[groupID]; found { + enabled = false + break + } + } + return enabled +} + +func (calc *NetworkMapCalculator) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup { + groupList := calc.components.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range calc.components.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + targetPeerInfo := calc.components.GetPeerInfo(peerID) + if targetPeerInfo != nil && !calc.peerIsNameserver(targetPeerInfo, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (calc *NetworkMapCalculator) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peerInfo.IP.String() == ns.IP.String() { + return true + } + } + return false +} + +func (calc *NetworkMapCalculator) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { + routes, peerDisabledRoutes := calc.getRoutingPeerRoutes(ctx, peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) 
{ + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + for _, peer := range aclPeers { + activeRoutes, _ := calc.getRoutingPeerRoutes(ctx, peer.ID) + groupFilteredRoutes := calc.filterRoutesByGroups(activeRoutes, peerGroups) + filteredRoutes := calc.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +func (calc *NetworkMapCalculator) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peerInfo := calc.components.GetPeerInfo(peerID) + if peerInfo == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + routeObj := calc.copyRoute(r) + routeObj.Peer = peerInfo.Key + + if r.Enabled { + enabledRoutes = append(enabledRoutes, routeObj) + return + } + disabledRoutes = append(disabledRoutes, routeObj) + } + + for _, r := range calc.components.Routes { + for _, groupID := range r.PeerGroups { + group := calc.components.GetGroupInfo(groupID) + if group == nil { + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := calc.copyRoute(r) + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) + takeRoute(newPeerRoute, id) + break + } + } + if r.Peer == peerID { + takeRoute(calc.copyRoute(r), peerID) + } + } + + return enabledRoutes, disabledRoutes +} + +func (calc *NetworkMapCalculator) copyRoute(r *route.Route) *route.Route { + var groups, accessControlGroups, peerGroups []string + var domains domain.List + + if r.Groups != nil { + groups = append([]string{}, r.Groups...) + } + if r.AccessControlGroups != nil { + accessControlGroups = append([]string{}, r.AccessControlGroups...) 
+ } + if r.PeerGroups != nil { + peerGroups = append([]string{}, r.PeerGroups...) + } + if r.Domains != nil { + domains = append(domain.List{}, r.Domains...) + } + + return &route.Route{ + ID: r.ID, + AccountID: r.AccountID, + Network: r.Network, + NetworkType: r.NetworkType, + Description: r.Description, + Peer: r.Peer, + PeerID: r.PeerID, + Metric: r.Metric, + Masquerade: r.Masquerade, + NetID: r.NetID, + Enabled: r.Enabled, + Groups: groups, + AccessControlGroups: accessControlGroups, + PeerGroups: peerGroups, + Domains: domains, + KeepRoute: r.KeepRoute, + SkipAutoApply: r.SkipAutoApply, + } +} + +func (calc *NetworkMapCalculator) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +func (calc *NetworkMapCalculator) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +func (calc *NetworkMapCalculator) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := calc.getRoutingPeerRoutes(ctx, peerID) + for _, r := range enabledRoutes { + if len(r.AccessControlGroups) == 0 { + defaultPermit := calc.getDefaultPermit(r) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) 
+ continue + } + + distributionPeers := calc.getDistributionGroupsPeers(r) + + for _, accessGroup := range r.AccessControlGroups { + policies := calc.getAllRoutePoliciesFromGroups([]string{accessGroup}) + rules := calc.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (calc *NetworkMapCalculator) findRoute(routeID route.ID) *route.Route { + for _, r := range calc.components.Routes { + if r.ID == routeID { + return r + } + } + + parts := strings.Split(string(routeID), ":") + if len(parts) > 1 { + baseRouteID := route.ID(parts[0]) + for _, r := range calc.components.Routes { + if r.ID == baseRouteID { + return r + } + } + } + + return nil +} + +func (calc *NetworkMapCalculator) getDefaultPermit(r *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if r.Network.Addr().Is6() { + sources = []string{"::/0"} + } + + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: r.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: r.Domains, + IsDynamic: r.IsDynamic(), + RouteID: r.ID, + } + + rules = append(rules, &rule) + + if r.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +func (calc *NetworkMapCalculator) getDistributionGroupsPeers(r *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range r.Groups { + group := calc.components.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (calc *NetworkMapCalculator) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + for _, policy := range 
calc.components.Policies { + for _, rule := range policy.Rules { + for _, destGroupID := range rule.Destinations { + if destGroupID == groupID { + routePolicies = append(routePolicies, policy) + break + } + } + } + } + } + + return routePolicies +} + +func (calc *NetworkMapCalculator) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := calc.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (calc *NetworkMapCalculator) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := calc.components.GetGroupInfo(id) + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := calc.components.Peers[pID] + if distPeer && valid && calc.components.ValidatePostureChecksOnPeer(pID, postureChecks) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + _, distPeer := distributionPeers[rule.SourceResource.ID] + _, valid := calc.components.Peers[rule.SourceResource.ID] + if distPeer && valid && calc.components.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) { + distPeersWithPolicy[rule.SourceResource.ID] = struct{}{} + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range 
distPeersWithPolicy { + peerInfo := calc.components.GetPeerInfo(pID) + if peerInfo == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peerInfo) + } + return distributionGroupPeers +} + +func (calc *NetworkMapCalculator) getNetworkResourcesRoutesToSync(ctx context.Context, peerID string) (bool, []*route.Route, map[string]struct{}) { + var isRoutingPeer bool + var routes []*route.Route + allSourcePeers := make(map[string]struct{}) + + for _, resource := range calc.components.NetworkResources { + if !resource.Enabled { + continue + } + + var addSourcePeers bool + + networkRoutingPeers, exists := calc.components.RoutersMap[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, calc.getNetworkResourcesRoutes(resource, peerID, router)...) + } + } + + addedResourceRoute := false + for _, policy := range calc.components.ResourcePoliciesMap[resource.ID] { + var peers []string + if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" { + peers = []string{policy.Rules[0].SourceResource.ID} + } else { + peers = calc.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + } + if addSourcePeers { + for _, pID := range calc.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } + } else if calc.peerInSlice(peerID, peers) && calc.components.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, calc.getNetworkResourcesRoutes(resource, peerId, router)...) 
+ } + addedResourceRoute = true + } + if addedResourceRoute { + break + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +func (calc *NetworkMapCalculator) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route { + resourceAppliedPolicies := calc.components.ResourcePoliciesMap[resource.ID] + + var routes []*route.Route + if len(resourceAppliedPolicies) > 0 { + peerInfo := calc.components.GetPeerInfo(peerID) + if peerInfo != nil { + routes = append(routes, calc.networkResourceToRoute(resource, peerInfo, router)) + } + } + + return routes +} + +func (calc *NetworkMapCalculator) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(resource.ID + ":" + peer.ID), + AccountID: resource.AccountID, + Peer: peer.Key, + PeerID: peer.ID, + Metric: router.Metric, + Masquerade: router.Masquerade, + Enabled: resource.Enabled, + KeepRoute: true, + NetID: route.NetID(resource.Name), + Description: resource.Description, + } + + if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet { + r.Network = resource.Prefix + + r.NetworkType = route.IPv4Network + if resource.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if resource.Type == resourceTypes.Domain { + domainList, err := domain.FromStringList([]string{resource.Domain}) + if err == nil { + r.Domains = domainList + r.NetworkType = route.DomainNetwork + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + } + + return r +} + +func (calc *NetworkMapCalculator) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if calc.components.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) { + dest = append(dest, peerID) + } + } + return dest +} + +func (calc 
*NetworkMapCalculator) peerInSlice(peerID string, peers []string) bool { + for _, p := range peers { + if p == peerID { + return true + } + } + return false +} + +func (calc *NetworkMapCalculator) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + peerInfo := calc.components.GetPeerInfo(peerID) + if peerInfo == nil { + return routesFirewallRules + } + + for _, r := range routes { + if r.Peer != peerInfo.Key { + continue + } + + resourceID := string(r.GetResourceID()) + resourcePolicies := calc.components.ResourcePoliciesMap[resourceID] + distributionPeers := calc.getPoliciesSourcePeers(resourcePolicies) + + rules := calc.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + } + + return routesFirewallRules +} + +func (calc *NetworkMapCalculator) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := calc.components.GetGroupInfo(sourceGroup) + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + sourcePeers[rule.SourceResource.ID] = struct{}{} + } + } + } + + return sourcePeers +} + +func (calc *NetworkMapCalculator) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peerID string, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { + + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + 
networkRoutesPeers[r.PeerID] = struct{}{} + } + + delete(sourcePeers, peerID) + delete(networkRoutesPeers, peerID) + + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} + } + } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + peerInfo := calc.components.GetPeerInfo(p) + if peerInfo != nil { + peersToConnect = append(peersToConnect, peerInfo) + } + } + + return peersToConnect +} diff --git a/management/server/types/networkmap_shadow.go b/management/server/types/networkmap_shadow.go new file mode 100644 index 000000000..d81ba39d8 --- /dev/null +++ b/management/server/types/networkmap_shadow.go @@ -0,0 +1,253 @@ +package types + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +var ( + shadowOutputDir = "netbird-shadow-compare" + EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED" +) + +func (a *Account) ShadowCompareNetworkMap( + ctx context.Context, + peerID string, + legacyNetworkMap *NetworkMap, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + groupIDToUserIDs map[string][]string, + metrics *telemetry.AccountManagerMetrics, +) { + + go func() { + defer 
// NetworkMapDiff holds, side by side, the element counts (and the EnableSSH
// flag) of a legacy network map and of a network map calculated from
// components. Each "...Legacy" field is paired with a "...Components" field.
type NetworkMapDiff struct {
	PeersLegacy               int
	PeersComponents           int
	OfflinePeersLegacy        int
	OfflinePeersComponents    int
	RoutesLegacy              int
	RoutesComponents          int
	FirewallRulesLegacy       int
	FirewallRulesComponents   int
	RouteFWRulesLegacy        int
	RouteFWRulesComponents    int
	ForwardingRulesLegacy     int
	ForwardingRulesComponents int
	DNSZonesLegacy            int
	DNSZonesComponents        int
	DNSNSGroupsLegacy         int
	DNSNSGroupsComponents     int
	EnableSSHLegacy           bool
	EnableSSHComponents       bool
	AuthUsersLegacy           int
	AuthUsersComponents       int
}

// networkMapDiffField is one legacy/components pair of a NetworkMapDiff
// together with the label used when reporting it.
type networkMapDiffField struct {
	label      string
	legacy     any
	components any
}

// fields lists every compared pair in reporting order. It is the single source
// of truth for HasDifferences and String, so the two cannot drift apart.
func (d *NetworkMapDiff) fields() []networkMapDiffField {
	return []networkMapDiffField{
		{"Peers", d.PeersLegacy, d.PeersComponents},
		{"OfflinePeers", d.OfflinePeersLegacy, d.OfflinePeersComponents},
		{"Routes", d.RoutesLegacy, d.RoutesComponents},
		{"FirewallRules", d.FirewallRulesLegacy, d.FirewallRulesComponents},
		{"RoutesFirewallRules", d.RouteFWRulesLegacy, d.RouteFWRulesComponents},
		{"ForwardingRules", d.ForwardingRulesLegacy, d.ForwardingRulesComponents},
		{"DNSZones", d.DNSZonesLegacy, d.DNSZonesComponents},
		{"DNSNSGroups", d.DNSNSGroupsLegacy, d.DNSNSGroupsComponents},
		{"EnableSSH", d.EnableSSHLegacy, d.EnableSSHComponents},
		{"AuthorizedUsers", d.AuthUsersLegacy, d.AuthUsersComponents},
	}
}

// HasDifferences reports whether any legacy value differs from its components
// counterpart.
func (d *NetworkMapDiff) HasDifferences() bool {
	for _, f := range d.fields() {
		// Both sides of a pair always share a type (int or bool), so
		// comparing the boxed values compares the underlying values.
		if f.legacy != f.components {
			return true
		}
	}
	return false
}

// String renders the differing pairs as "<Label>: <legacy> vs <components>",
// joined by ", ". It returns "no differences" when every pair matches.
func (d *NetworkMapDiff) String() string {
	var out []byte
	for _, f := range d.fields() {
		if f.legacy == f.components {
			continue
		}
		if len(out) > 0 {
			out = append(out, ", "...)
		}
		// %v prints ints exactly like %d and bools as true/false.
		out = append(out, fmt.Sprintf("%s: %v vs %v", f.label, f.legacy, f.components)...)
	}

	if len(out) == 0 {
		return "no differences"
	}
	return string(out)
}
len(components.AuthorizedUsers) + } + + return diff +} + +func saveMismatchedMaps(ctx context.Context, accountID, peerID string, legacy, components *NetworkMap, diff NetworkMapDiff) { + outputDir := filepath.Join(shadowOutputDir, accountID, peerID) + + if err := os.MkdirAll(outputDir, 0755); err != nil { + log.WithContext(ctx).Errorf("failed to create shadow output dir: %v", err) + return + } + + timestamp := time.Now().Format("20060102-150405") + + legacyPath := filepath.Join(outputDir, timestamp+"_legacy.json") + componentsPath := filepath.Join(outputDir, timestamp+"_components.json") + diffPath := filepath.Join(outputDir, timestamp+"_diff.json") + + if legacyJSON, err := json.MarshalIndent(legacy, "", " "); err == nil { + if err := os.WriteFile(legacyPath, legacyJSON, 0644); err != nil { + log.WithContext(ctx).Errorf("failed to write legacy map: %v", err) + } + } + + if componentsJSON, err := json.MarshalIndent(components, "", " "); err == nil { + if err := os.WriteFile(componentsPath, componentsJSON, 0644); err != nil { + log.WithContext(ctx).Errorf("failed to write components map: %v", err) + } + } + + if diffJSON, err := json.MarshalIndent(diff, "", " "); err == nil { + if err := os.WriteFile(diffPath, diffJSON, 0644); err != nil { + log.WithContext(ctx).Errorf("failed to write diff: %v", err) + } + } + + log.WithContext(ctx).Infof("shadow comparison mismatch saved to %s", outputDir) +}