diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 4b47ecaa0..5cbf6fceb 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -261,6 +261,137 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) e return c.sendUpdateAccountPeers(ctx, accountID) } +// UpdateAffectedPeers updates only the specified peers that belong to an account. +// Should be called when a change is known to affect only a subset of peers. +// If peerIDs is empty, this is a no-op. +func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error { + if len(peerIDs) == 0 { + return nil + } + return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs) +} + +func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error { + log.WithContext(ctx).Tracef("updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName()) + + affected := make(map[string]struct{}, len(peerIDs)) + for _, id := range peerIDs { + affected[id] = struct{}{} + } + + // Fast check: any of the affected peers actually connected? + hasConnected := false + for _, id := range peerIDs { + if c.peersUpdateManager.HasChannel(id) { + hasConnected = true + break + } + } + if !hasConnected { + return nil + } + + account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %v", err) + } + + globalStart := time.Now() + + // Collect the subset of account peers that are both affected and connected. + var peersToUpdate []*nbpeer.Peer + for _, peer := range account.Peers { + if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) { + peersToUpdate = append(peersToUpdate, peer) + } + } + + if len(peersToUpdate) == 0 { + return nil + } + + approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + return fmt.Errorf("failed to get validate peers: %v", err) + } + + var wg sync.WaitGroup + semaphore := make(chan struct{}, 10) + + account.InjectProxyPolicies(ctx) + dnsCache := &cache.DNSConfigCache{} + dnsDomain := c.GetDNSDomain(account.Settings) + peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + groupIDToUserIDs := account.GetActiveGroupUsers() + + proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) + if err != nil { + log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) + return fmt.Errorf("failed to get proxy network maps: %v", err) + } + + extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get flow enabled status: %v", err) + } + + dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + + accountZones, err := c.repo.GetAccountZones(ctx, account.Id) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account zones: %v", err) + return fmt.Errorf("failed to get account zones: %v", err) + } + + for _, peer := range peersToUpdate { + wg.Add(1) + semaphore <- struct{}{} + go func(p *nbpeer.Peer) { + defer wg.Done() + defer func() { <-semaphore }() + + start := time.Now() + + postureChecks, err := c.getPeerPostureChecks(account, p.ID) + if err != nil { + log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err) + return + } + + c.metrics.CountCalcPostureChecksDuration(time.Since(start)) + start = time.Now() + + remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) + + c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) + + proxyNetworkMap, ok := proxyNetworkMaps[p.ID] + if ok { + remotePeerNetworkMap.Merge(proxyNetworkMap) + } + + peerGroups := account.GetPeerGroups(p.ID) + start = time.Now() + update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) + c.metrics.CountToSyncResponseDuration(time.Since(start)) + + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) + }(peer) + } + + wg.Wait() + if c.accountManagerMetrics != nil { + c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart)) + } + + return nil +} + func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { if !c.peersUpdateManager.HasChannel(peerId) { return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index cfea2d3de..8d81556f9 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -19,6 +19,7 @@ const ( type Controller interface { UpdateAccountPeers(ctx context.Context, accountID string) error + UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error BufferUpdateAccountPeers(ctx context.Context, accountID string) error GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 4e86d2973..b2ef0b861 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -250,3 +250,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID) } + +// UpdateAffectedPeers mocks base method. +func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers. +func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs) +} diff --git a/management/server/account/manager.go b/management/server/account/manager.go index b4516d512..576054d1e 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -125,6 +125,7 @@ type Manager interface { GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) + UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 36e5fe39f..c595346f8 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -1608,6 +1608,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID) } +// UpdateAffectedPeers mocks base method. +func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs) +} + +// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers. +func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs) +} + // UpdateAccountSettings mocks base method. func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { m.ctrl.T.Helper() diff --git a/management/server/dns.go b/management/server/dns.go index baf6debc3..1e213ffbb 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return status.NewPermissionDeniedError() } - var updateAccountPeers bool var eventsToStore []func() + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { @@ -63,11 +63,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) - updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) - if err != nil { - return err - } - events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) @@ -75,6 +70,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return err } + allGroups := slices.Concat(addedGroups, removedGroups) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -85,8 +83,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID storeEvent() } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -133,20 +131,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t return eventsToStore } -// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. -func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) { - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups) - if err != nil { - return false, err - } - - if hasPeers { - return true, nil - } - - return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups) -} - // validateDNSSettings validates the DNS settings. func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error { if len(settings.DisabledManagementGroups) == 0 { diff --git a/management/server/group.go b/management/server/group.go index 7b5b9b86c..4bd249398 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -79,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } var eventsToStore []func() - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { @@ -91,11 +91,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) eventsToStore = append(eventsToStore, events...) - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) - if err != nil { - return err - } - if err := transaction.CreateGroup(ctx, newGroup); err != nil { return status.Errorf(status.Internal, "failed to create group: %v", err) } @@ -106,6 +101,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } } + groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID}) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -116,8 +114,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use storeEvent() } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -134,7 +132,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } var eventsToStore []func() - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { @@ -165,15 +163,13 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) - if err != nil { - return err - } - if err = transaction.UpdateGroup(ctx, newGroup); err != nil { return err } + groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID}) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -184,8 +180,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use storeEvent() } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -205,7 +201,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var updateAccountPeers bool var globalErr error groupIDs := make([]string, 0, len(groups)) @@ -243,17 +238,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) - if err != nil { - return err - } - for _, storeEvent := range eventsToStore { storeEvent() } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs) + affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return globalErr @@ -273,7 +265,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } var eventsToStore []func() - var updateAccountPeers bool var globalErr error groupIDs := make([]string, 0, len(groups)) @@ -311,17 +302,14 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs) - if err != nil { - return err - } - for _, storeEvent := range eventsToStore { storeEvent() } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs) + affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return globalErr @@ -473,27 +461,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - var updateAccountPeers bool + var affectedPeerIDs []string var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) - if err != nil { - return err - } - if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil { return err } + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID}) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -502,7 +488,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupAddResource appends resource to the group func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { var group *types.Group - var updateAccountPeers bool + var affectedPeerIDs []string var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -515,23 +501,21 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID return nil } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) - if err != nil { - return err - } - if err = transaction.UpdateGroup(ctx, group); err != nil { return err } + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID}) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -539,14 +523,13 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - var updateAccountPeers bool + var affectedPeerIDs []string var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) - if err != nil { - return err - } + // Resolve before removing, so the peer being removed is still included + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID}) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs) if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil { return err @@ -558,8 +541,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -568,7 +551,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, // GroupDeleteResource removes resource from the group func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { var group *types.Group - var updateAccountPeers bool + var affectedPeerIDs []string var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -581,23 +564,21 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun return nil } - updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) - if err != nil { - return err - } - if err = transaction.UpdateGroup(ctx, group); err != nil { return err } + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID}) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { return err } - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -840,18 +821,175 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac return false, nil } -// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. -func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs) - if err != nil { - return false, err +// collectGroupChangeAffectedGroups walks all entities that reference the changed groups +// and collects the full set of affected group IDs and direct peer IDs. +// This ensures that when a group changes, we update not just the peers in that group +// but also peers in other groups that share policies, routes, DNS, or nameserver configs. +func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) { + if len(changedGroupIDs) == 0 { + return nil, nil } - for _, group := range groups { - if group.HasPeers() || group.HasResources() { - return true, nil + changedSet := make(map[string]struct{}, len(changedGroupIDs)) + for _, id := range changedGroupIDs { + changedSet[id] = struct{}{} + } + + groupSet := make(map[string]struct{}) + // Always include the changed groups themselves + for _, id := range changedGroupIDs { + groupSet[id] = struct{}{} + } + + peerSet := make(map[string]struct{}) + + // Policies: collect all rule groups + direct peer resources from policies that reference any changed group + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err) + } else { + for _, policy := range policies { + if !policyReferencesGroups(policy, changedSet) { + continue + } + for _, gID := range policy.RuleGroups() { + groupSet[gID] = struct{}{} + } + for _, rule := range policy.Rules { + if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { + peerSet[rule.SourceResource.ID] = struct{}{} + } + if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { + peerSet[rule.DestinationResource.ID] = struct{}{} + } + } } } - return false, nil + // Routes: collect all groups + direct peer from routes that reference any changed group + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err) + } else { + for _, r := range routes { + if !routeReferencesGroups(r, changedSet) { + continue + } + for _, gID := range r.Groups { + groupSet[gID] = struct{}{} + } + for _, gID := range r.PeerGroups { + groupSet[gID] = struct{}{} + } + for _, gID := range r.AccessControlGroups { + groupSet[gID] = struct{}{} + } + if r.Peer != "" { + peerSet[r.Peer] = struct{}{} + } + } + } + + // Nameserver groups: collect groups from NS groups that reference any changed group + nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err) + } else { + for _, ns := range nsGroups { + for _, gID := range ns.Groups { + if _, ok := changedSet[gID]; ok { + for _, g := range ns.Groups { + groupSet[g] = struct{}{} + } + break + } + } + } + } + + // DNS settings: if any changed group is in DisabledManagementGroups, include those groups + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err) + } else { + for _, gID := range dnsSettings.DisabledManagementGroups { + if _, ok := changedSet[gID]; ok { + groupSet[gID] = struct{}{} + } + } + } + + // Network routers: collect peer groups + direct peer from routers that reference any changed group + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err) + } else { + for _, router := range routers { + if !routerReferencesGroups(router, changedSet) { + continue + } + for _, gID := range router.PeerGroups { + groupSet[gID] = struct{}{} + } + if router.Peer != "" { + peerSet[router.Peer] = struct{}{} + } + } + } + + allGroupIDs = make([]string, 0, len(groupSet)) + for gID := range groupSet { + allGroupIDs = append(allGroupIDs, gID) + } + + directPeerIDs = make([]string, 0, len(peerSet)) + for pID := range peerSet { + directPeerIDs = append(directPeerIDs, pID) + } + + return allGroupIDs, directPeerIDs +} + +func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { + for _, rule := range policy.Rules { + for _, gID := range rule.Sources { + if _, ok := groupSet[gID]; ok { + return true + } + } + for _, gID := range rule.Destinations { + if _, ok := groupSet[gID]; ok { + return true + } + } + } + return false +} + +func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool { + for _, gID := range r.Groups { + if _, ok := groupSet[gID]; ok { + return true + } + } + for _, gID := range r.PeerGroups { + if _, ok := groupSet[gID]; ok { + return true + } + } + for _, gID := range r.AccessControlGroups { + if _, ok := groupSet[gID]; ok { + return true + } + } + return false +} + +func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool { + for _, gID := range router.PeerGroups { + if _, ok := groupSet[gID]; ok { + return true + } + } + return false } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ff369355e..5a9009da7 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -129,6 +129,7 @@ type MockAccountManager struct { AllowSyncFunc func(string, uint64) bool UpdateAccountPeersFunc func(ctx context.Context, accountID string) + UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string) BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error @@ -206,6 +207,12 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID } } +func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) { + if am.UpdateAffectedPeersFunc != nil { + am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs) + } +} + func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { if am.BufferUpdateAccountPeersFunc != nil { am.BufferUpdateAccountPeersFunc(ctx, accountID) diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 3d8c78912..823fc72d5 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "unicode/utf8" @@ -57,22 +58,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco SearchDomainsEnabled: searchDomainEnabled, } - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { return err } - updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups) - if err != nil { - return err - } - if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil { return err } + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -81,8 +79,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return newNSGroup.Copy(), nil @@ -102,7 +100,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.NewPermissionDeniedError() } - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID) @@ -115,15 +113,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup) - if err != nil { - return err - } - if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil { return err } + allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -132,8 +128,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -150,7 +146,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } var nsGroup *nbdns.NameServerGroup - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) @@ -158,10 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups) - if err != nil { - return err - } + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil) if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil { return err @@ -175,8 +168,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -224,24 +217,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou return validateGroups(nameserverGroup.Groups, groups) } -// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { - if !newNSGroup.Enabled && !oldNSGroup.Enabled { - return false, nil - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) - if err != nil { - return false, err - } - - if hasPeers { - return true, nil - } - - return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) -} - func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { if !primary && len(domains) == 0 { return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+ diff --git a/management/server/peer.go b/management/server/peer.go index a95ae17a3..39368c840 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1294,6 +1294,38 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account _ = am.networkMapController.UpdateAccountPeers(ctx, accountID) } +// UpdateAffectedPeers updates only the specified peers that belong to an account. +// Should be called when a change is known to affect only a subset of peers. +func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) { + _ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs) +} + +// resolvePeerIDs resolves a set of group IDs and direct peer IDs into a +// deduplicated list of peer IDs suitable for UpdateAffectedPeers. +func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string { + peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err) + return nil + } + + if len(directPeerIDs) == 0 { + return peerIDs + } + + seen := make(map[string]struct{}, len(peerIDs)) + for _, id := range peerIDs { + seen[id] = struct{}{} + } + for _, id := range directPeerIDs { + if _, exists := seen[id]; !exists { + peerIDs = append(peerIDs, id) + seen[id] = struct{}{} + } + } + return peerIDs +} + func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) } diff --git a/management/server/policy.go b/management/server/policy.go index 48297ca11..9ba4f98dd 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -45,12 +45,13 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user } var isUpdate = policy.ID != "" - var updateAccountPeers bool + var existingPolicy *types.Policy var action = activity.PolicyAdded var unchanged bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) + existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy) if err != nil { return err } @@ -64,25 +65,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user action = activity.PolicyUpdated - updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) - if err != nil { - return err - } - if err = transaction.SavePolicy(ctx, policy); err != nil { return err } } else { - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) - if err != nil { - return err - } - if err = transaction.CreatePolicy(ctx, policy); err != nil { return err } } + groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy, existingPolicy) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -95,8 +89,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return policy, nil @@ -113,7 +107,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po } var policy *types.Policy - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) @@ -121,10 +115,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) - if err != nil { - return err - } + groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil { return err @@ -138,8 +130,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -158,44 +150,24 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { - for _, rule := range policy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil +// collectPolicyAffectedGroupsAndPeers returns the group IDs and direct peer IDs +// referenced by the given policies' rules. +func collectPolicyAffectedGroupsAndPeers(policies ...*types.Policy) (groupIDs []string, directPeerIDs []string) { + for _, policy := range policies { + if policy == nil { + continue + } + groupIDs = append(groupIDs, policy.RuleGroups()...) + for _, rule := range policy.Rules { + if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { + directPeerIDs = append(directPeerIDs, rule.SourceResource.ID) + } + if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { + directPeerIDs = append(directPeerIDs, rule.DestinationResource.ID) + } } } - - return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) -} - -func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) { - if !policy.Enabled && !existingPolicy.Enabled { - return false, nil - } - - for _, rule := range existingPolicy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { - return true, nil - } - - for _, rule := range policy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) + return } // validatePolicy validates the policy and its rules. For updates it returns diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 9562487c0..bbf4ed198 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -40,9 +40,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return nil, status.NewPermissionDeniedError() } - var updateAccountPeers bool var isUpdate = postureChecks.ID != "" var action = activity.PostureCheckCreated + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { @@ -50,12 +50,10 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } if isUpdate { - updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) - if err != nil { - return err - } - action = activity.PostureCheckUpdated + + groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(ctx, transaction, accountID, postureChecks.ID) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) } postureChecks.AccountID = accountID @@ -75,8 +73,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return postureChecks, nil @@ -132,27 +130,23 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } -// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { +// collectPostureCheckAffectedGroupsAndPeers finds all policies referencing the given posture check +// and collects their affected group IDs and direct peer IDs. +func collectPostureCheckAffectedGroupsAndPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (groupIDs []string, directPeerIDs []string) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { - return false, err + return nil, nil } for _, policy := range policies { if slices.Contains(policy.SourcePostureChecks, postureCheckID) { - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { - return true, nil - } + gIDs, pIDs := collectPolicyAffectedGroupsAndPeers(policy) + groupIDs = append(groupIDs, gIDs...) + directPeerIDs = append(directPeerIDs, pIDs...) } } - return false, nil + return groupIDs, directPeerIDs } // validatePostureChecks validates the posture checks. diff --git a/management/server/route.go b/management/server/route.go index 2b4f11d05..30297f851 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -147,7 +147,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } var newRoute *route.Route - var updateAccountPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { newRoute = &route.Route{ @@ -173,15 +173,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return err } - updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute) - if err != nil { - return err - } - if err = transaction.SaveRoute(ctx, newRoute); err != nil { return err } + groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(newRoute) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -190,8 +188,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return newRoute, nil @@ -208,8 +206,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } var oldRoute *route.Route - var oldRouteAffectsPeers bool - var newRouteAffectsPeers bool + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil { @@ -221,21 +218,15 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute) - if err != nil { - return err - } - - newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave) - if err != nil { - return err - } routeToSave.AccountID = accountID if err = transaction.SaveRoute(ctx, routeToSave); err != nil { return err } + groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(routeToSave, oldRoute) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) + return transaction.IncrementNetworkSerial(ctx, accountID) }) if err != nil { @@ -244,8 +235,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) - if oldRouteAffectsPeers || newRouteAffectsPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -261,19 +252,17 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return status.NewPermissionDeniedError() } - var route *route.Route - var updateAccountPeers bool + var rt *route.Route + var affectedPeerIDs []string err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) if err != nil { return err } - updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route) - if err != nil { - return err - } + groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(rt) + affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs) if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil { return err @@ -285,10 +274,10 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return fmt.Errorf("failed to delete route %s: %w", routeID, err) } - am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) + am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta()) - if updateAccountPeers { - am.UpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) > 0 { + am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } return nil @@ -377,23 +366,20 @@ func getPlaceholderIP() netip.Prefix { return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -// areRouteChangesAffectPeers checks if a given route affects peers by determining -// if it has a routing peer, distribution, or peer groups that include peers. -func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) { - if route.Peer != "" { - return true, nil +// collectRouteAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given routes. +func collectRouteAffectedGroupsAndPeers(routes ...*route.Route) (groupIDs []string, directPeerIDs []string) { + for _, r := range routes { + if r == nil { + continue + } + groupIDs = append(groupIDs, r.Groups...) + groupIDs = append(groupIDs, r.PeerGroups...) + groupIDs = append(groupIDs, r.AccessControlGroups...) + if r.Peer != "" { + directPeerIDs = append(directPeerIDs, r.Peer) + } } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups) - if err != nil { - return false, err - } - - if hasPeers { - return true, nil - } - - return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups) + return } // GetRoutesByPrefixOrDomains return list of routes by account and route prefix diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0a716d08d..c285b70c7 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4662,6 +4662,23 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro return peers, nil } +func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) { + if len(groupIDs) == 0 { + return nil, nil + } + + var peerIDs []string + result := s.db.Model(&types.GroupPeer{}). + Select("DISTINCT peer_id"). + Where("account_id = ? AND group_id IN ?", accountID, groupIDs). + Pluck("peer_id", &peerIDs) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get peer IDs by groups: %s", result.Error) + } + + return peerIDs, nil +} + func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) { tx := s.db if lockStrength != LockingStrengthNone { diff --git a/management/server/store/store.go b/management/server/store/store.go index 0d8b0678a..82489615f 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -159,6 +159,7 @@ type Store interface { GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) + GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index beee13d96..70366ed44 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr) } + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -1852,6 +1853,21 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs) } +// GetPeerIDsByGroups mocks base method. +func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups. +func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs) +} + // GetPeersByIDs mocks base method. func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) { m.ctrl.T.Helper()