calculate affected peers

This commit is contained in:
pascal
2026-04-27 17:49:12 +02:00
parent 154b81645a
commit 285bbc5ffb
16 changed files with 540 additions and 259 deletions

View File

@@ -125,6 +125,7 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string)
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAccountPeers(ctx context.Context, accountID string)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -1608,6 +1608,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID)
}
// UpdateAffectedPeers mocks base method.
func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
@@ -63,11 +63,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
@@ -75,6 +70,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
allGroups := slices.Concat(addedGroups, removedGroups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -85,8 +83,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -133,20 +131,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {

View File

@@ -79,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -91,11 +91,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -106,6 +101,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -116,8 +114,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -134,7 +132,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -165,15 +163,13 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -184,8 +180,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -205,7 +201,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -243,17 +238,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return globalErr
@@ -273,7 +265,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -311,17 +302,14 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return globalErr
@@ -473,27 +461,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -502,7 +488,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -515,23 +501,21 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -539,14 +523,13 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
// Resolve before removing, so the peer being removed is still included
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return err
@@ -558,8 +541,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -568,7 +551,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -581,23 +564,21 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -840,18 +821,175 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil {
return false, err
// collectGroupChangeAffectedGroups walks all entities that reference the changed groups
// and collects the full set of affected group IDs and direct peer IDs.
// This ensures that when a group changes, we update not just the peers in that group
// but also peers in other groups that share policies, routes, DNS, or nameserver configs.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 {
return nil, nil
}
for _, group := range groups {
if group.HasPeers() || group.HasResources() {
return true, nil
changedSet := make(map[string]struct{}, len(changedGroupIDs))
for _, id := range changedGroupIDs {
changedSet[id] = struct{}{}
}
groupSet := make(map[string]struct{})
// Always include the changed groups themselves
for _, id := range changedGroupIDs {
groupSet[id] = struct{}{}
}
peerSet := make(map[string]struct{})
// Policies: collect all rule groups + direct peer resources from policies that reference any changed group
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
} else {
for _, policy := range policies {
if !policyReferencesGroups(policy, changedSet) {
continue
}
for _, gID := range policy.RuleGroups() {
groupSet[gID] = struct{}{}
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
}
return false, nil
// Routes: collect all groups + direct peer from routes that reference any changed group
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
} else {
for _, r := range routes {
if !routeReferencesGroups(r, changedSet) {
continue
}
for _, gID := range r.Groups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.PeerGroups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.AccessControlGroups {
groupSet[gID] = struct{}{}
}
if r.Peer != "" {
peerSet[r.Peer] = struct{}{}
}
}
}
// Nameserver groups: collect groups from NS groups that reference any changed group
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
} else {
for _, ns := range nsGroups {
for _, gID := range ns.Groups {
if _, ok := changedSet[gID]; ok {
for _, g := range ns.Groups {
groupSet[g] = struct{}{}
}
break
}
}
}
}
// DNS settings: if any changed group is in DisabledManagementGroups, include those groups
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
} else {
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedSet[gID]; ok {
groupSet[gID] = struct{}{}
}
}
}
// Network routers: collect peer groups + direct peer from routers that reference any changed group
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
} else {
for _, router := range routers {
if !routerReferencesGroups(router, changedSet) {
continue
}
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
if router.Peer != "" {
peerSet[router.Peer] = struct{}{}
}
}
}
allGroupIDs = make([]string, 0, len(groupSet))
for gID := range groupSet {
allGroupIDs = append(allGroupIDs, gID)
}
directPeerIDs = make([]string, 0, len(peerSet))
for pID := range peerSet {
directPeerIDs = append(directPeerIDs, pID)
}
return allGroupIDs, directPeerIDs
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
for _, gID := range rule.Sources {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range rule.Destinations {
if _, ok := groupSet[gID]; ok {
return true
}
}
}
return false
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
for _, gID := range r.Groups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.AccessControlGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
for _, gID := range router.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}

View File

@@ -129,6 +129,7 @@ type MockAccountManager struct {
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
@@ -206,6 +207,12 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
}
}
func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
if am.UpdateAffectedPeersFunc != nil {
am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs)
}
}
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID)

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"unicode/utf8"
@@ -57,22 +58,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -81,8 +79,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return newNSGroup.Copy(), nil
@@ -102,7 +100,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
@@ -115,15 +113,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err
}
allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -132,8 +128,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -150,7 +146,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
@@ -158,10 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil)
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
return err
@@ -175,8 +168,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -224,24 +217,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+

View File

@@ -1294,6 +1294,38 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
// Should be called when a change is known to affect only a subset of peers.
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
}
// resolvePeerIDs resolves a set of group IDs and direct peer IDs into a
// deduplicated list of peer IDs suitable for UpdateAffectedPeers.
func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string {
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
}

View File

@@ -45,12 +45,13 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var existingPolicy *types.Policy
var action = activity.PolicyAdded
var unchanged bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy)
if err != nil {
return err
}
@@ -64,25 +65,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
action = activity.PolicyUpdated
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
}
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy, existingPolicy)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -95,8 +89,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return policy, nil
@@ -113,7 +107,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
}
var policy *types.Policy
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
@@ -121,10 +115,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
return err
@@ -138,8 +130,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -158,44 +150,24 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
}
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
// collectPolicyAffectedGroupsAndPeers returns the group IDs and direct peer IDs
// referenced by the given policies' rules.
func collectPolicyAffectedGroupsAndPeers(policies ...*types.Policy) (groupIDs []string, directPeerIDs []string) {
for _, policy := range policies {
if policy == nil {
continue
}
groupIDs = append(groupIDs, policy.RuleGroups()...)
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.DestinationResource.ID)
}
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
return
}
// validatePolicy validates the policy and its rules. For updates it returns

View File

@@ -40,9 +40,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
@@ -50,12 +50,10 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
action = activity.PostureCheckUpdated
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(ctx, transaction, accountID, postureChecks.ID)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
}
postureChecks.AccountID = accountID
@@ -75,8 +73,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return postureChecks, nil
@@ -132,27 +130,23 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
// collectPostureCheckAffectedGroupsAndPeers finds all policies referencing the given posture check
// and collects their affected group IDs and direct peer IDs.
func collectPostureCheckAffectedGroupsAndPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (groupIDs []string, directPeerIDs []string) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
return nil, nil
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
gIDs, pIDs := collectPolicyAffectedGroupsAndPeers(policy)
groupIDs = append(groupIDs, gIDs...)
directPeerIDs = append(directPeerIDs, pIDs...)
}
}
return false, nil
return groupIDs, directPeerIDs
}
// validatePostureChecks validates the posture checks.

View File

@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
var newRoute *route.Route
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
newRoute = &route.Route{
@@ -173,15 +173,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(newRoute)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -190,8 +188,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return newRoute, nil
@@ -208,8 +206,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
@@ -221,21 +218,15 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(routeToSave, oldRoute)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -244,8 +235,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -261,19 +252,17 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
var rt *route.Route
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(rt)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
return err
@@ -285,10 +274,10 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
}
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -377,23 +366,20 @@ func getPlaceholderIP() netip.Prefix {
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
if route.Peer != "" {
return true, nil
// collectRouteAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given routes.
func collectRouteAffectedGroupsAndPeers(routes ...*route.Route) (groupIDs []string, directPeerIDs []string) {
for _, r := range routes {
if r == nil {
continue
}
groupIDs = append(groupIDs, r.Groups...)
groupIDs = append(groupIDs, r.PeerGroups...)
groupIDs = append(groupIDs, r.AccessControlGroups...)
if r.Peer != "" {
directPeerIDs = append(directPeerIDs, r.Peer)
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
return
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix

View File

@@ -4662,6 +4662,23 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
return peers, nil
}
func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
if len(groupIDs) == 0 {
return nil, nil
}
var peerIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT peer_id").
Where("account_id = ? AND group_id IN ?", accountID, groupIDs).
Pluck("peer_id", &peerIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get peer IDs by groups: %s", result.Error)
}
return peerIDs, nil
}
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -159,6 +159,7 @@ type Store interface {
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)

View File

@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -1852,6 +1853,21 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()