mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-29 13:46:41 +00:00
calculate affected peers
This commit is contained in:
@@ -125,6 +125,7 @@ type Manager interface {
|
||||
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
|
||||
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string)
|
||||
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string)
|
||||
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
|
||||
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error
|
||||
|
||||
@@ -1608,6 +1608,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID interface{}
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers mocks base method.
|
||||
func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
|
||||
func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// UpdateAccountSettings mocks base method.
|
||||
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
|
||||
@@ -63,11 +63,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
|
||||
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
|
||||
|
||||
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
@@ -75,6 +70,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return err
|
||||
}
|
||||
|
||||
allGroups := slices.Concat(addedGroups, removedGroups)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -85,8 +83,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -133,20 +131,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
|
||||
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
|
||||
}
|
||||
|
||||
// validateDNSSettings validates the DNS settings.
|
||||
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
|
||||
if len(settings.DisabledManagementGroups) == 0 {
|
||||
|
||||
@@ -79,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -91,11 +91,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -106,6 +101,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -116,8 +114,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -134,7 +132,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -165,15 +163,13 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -184,8 +180,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -205,7 +201,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
@@ -243,17 +238,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
|
||||
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -273,7 +265,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
@@ -311,17 +302,14 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
|
||||
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -473,27 +461,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -502,7 +488,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -515,23 +501,21 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -539,14 +523,13 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Resolve before removing, so the peer being removed is still included
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return err
|
||||
@@ -558,8 +541,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -568,7 +551,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -581,23 +564,21 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -840,18 +821,175 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// collectGroupChangeAffectedGroups walks all entities that reference the changed groups
|
||||
// and collects the full set of affected group IDs and direct peer IDs.
|
||||
// This ensures that when a group changes, we update not just the peers in that group
|
||||
// but also peers in other groups that share policies, routes, DNS, or nameserver configs.
|
||||
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
|
||||
if len(changedGroupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() || group.HasResources() {
|
||||
return true, nil
|
||||
changedSet := make(map[string]struct{}, len(changedGroupIDs))
|
||||
for _, id := range changedGroupIDs {
|
||||
changedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
// Always include the changed groups themselves
|
||||
for _, id := range changedGroupIDs {
|
||||
groupSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
peerSet := make(map[string]struct{})
|
||||
|
||||
// Policies: collect all rule groups + direct peer resources from policies that reference any changed group
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, policy := range policies {
|
||||
if !policyReferencesGroups(policy, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range policy.RuleGroups() {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
// Routes: collect all groups + direct peer from routes that reference any changed group
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, r := range routes {
|
||||
if !routeReferencesGroups(r, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range r.Groups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
for _, gID := range r.PeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
for _, gID := range r.AccessControlGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
if r.Peer != "" {
|
||||
peerSet[r.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Nameserver groups: collect groups from NS groups that reference any changed group
|
||||
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, ns := range nsGroups {
|
||||
for _, gID := range ns.Groups {
|
||||
if _, ok := changedSet[gID]; ok {
|
||||
for _, g := range ns.Groups {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DNS settings: if any changed group is in DisabledManagementGroups, include those groups
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, gID := range dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := changedSet[gID]; ok {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Network routers: collect peer groups + direct peer from routers that reference any changed group
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, router := range routers {
|
||||
if !routerReferencesGroups(router, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range router.PeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
if router.Peer != "" {
|
||||
peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allGroupIDs = make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
allGroupIDs = append(allGroupIDs, gID)
|
||||
}
|
||||
|
||||
directPeerIDs = make([]string, 0, len(peerSet))
|
||||
for pID := range peerSet {
|
||||
directPeerIDs = append(directPeerIDs, pID)
|
||||
}
|
||||
|
||||
return allGroupIDs, directPeerIDs
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, gID := range rule.Sources {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
|
||||
for _, gID := range r.Groups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, gID := range r.PeerGroups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, gID := range r.AccessControlGroups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
|
||||
for _, gID := range router.PeerGroups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -129,6 +129,7 @@ type MockAccountManager struct {
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
|
||||
@@ -206,6 +207,12 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
|
||||
if am.UpdateAffectedPeersFunc != nil {
|
||||
am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
if am.BufferUpdateAccountPeersFunc != nil {
|
||||
am.BufferUpdateAccountPeersFunc(ctx, accountID)
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
|
||||
@@ -57,22 +58,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
SearchDomainsEnabled: searchDomainEnabled,
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -81,8 +79,8 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
@@ -102,7 +100,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
|
||||
@@ -115,15 +113,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -132,8 +128,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -150,7 +146,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
}
|
||||
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
|
||||
@@ -158,10 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil)
|
||||
|
||||
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
|
||||
return err
|
||||
@@ -175,8 +168,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -224,24 +217,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
|
||||
return validateGroups(nameserverGroup.Groups, groups)
|
||||
}
|
||||
|
||||
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
|
||||
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
|
||||
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
|
||||
}
|
||||
|
||||
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
|
||||
if !primary && len(domains) == 0 {
|
||||
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+
|
||||
|
||||
@@ -1294,6 +1294,38 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
// Should be called when a change is known to affect only a subset of peers.
|
||||
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
|
||||
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// resolvePeerIDs resolves a set of group IDs and direct peer IDs into a
|
||||
// deduplicated list of peer IDs suitable for UpdateAffectedPeers.
|
||||
func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string {
|
||||
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(directPeerIDs) == 0 {
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for _, id := range directPeerIDs {
|
||||
if _, exists := seen[id]; !exists {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -45,12 +45,13 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
var isUpdate = policy.ID != ""
|
||||
var updateAccountPeers bool
|
||||
var existingPolicy *types.Policy
|
||||
var action = activity.PolicyAdded
|
||||
var unchanged bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
|
||||
existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -64,25 +65,18 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
action = activity.PolicyUpdated
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SavePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.CreatePolicy(ctx, policy); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy, existingPolicy)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -95,8 +89,8 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
@@ -113,7 +107,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
}
|
||||
|
||||
var policy *types.Policy
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
|
||||
@@ -121,10 +115,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groupIDs, directPeerIDs := collectPolicyAffectedGroupsAndPeers(policy)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
|
||||
return err
|
||||
@@ -138,8 +130,8 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -158,44 +150,24 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
|
||||
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
|
||||
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
// collectPolicyAffectedGroupsAndPeers returns the group IDs and direct peer IDs
|
||||
// referenced by the given policies' rules.
|
||||
func collectPolicyAffectedGroupsAndPeers(policies ...*types.Policy) (groupIDs []string, directPeerIDs []string) {
|
||||
for _, policy := range policies {
|
||||
if policy == nil {
|
||||
continue
|
||||
}
|
||||
groupIDs = append(groupIDs, policy.RuleGroups()...)
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
directPeerIDs = append(directPeerIDs, rule.DestinationResource.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
}
|
||||
|
||||
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
|
||||
if !policy.Enabled && !existingPolicy.Enabled {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, rule := range existingPolicy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
|
||||
return
|
||||
}
|
||||
|
||||
// validatePolicy validates the policy and its rules. For updates it returns
|
||||
|
||||
@@ -40,9 +40,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var isUpdate = postureChecks.ID != ""
|
||||
var action = activity.PostureCheckCreated
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||
@@ -50,12 +50,10 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
|
||||
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
@@ -75,8 +73,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
@@ -132,27 +130,23 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
// collectPostureCheckAffectedGroupsAndPeers finds all policies referencing the given posture check
|
||||
// and collects their affected group IDs and direct peer IDs.
|
||||
func collectPostureCheckAffectedGroupsAndPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (groupIDs []string, directPeerIDs []string) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
gIDs, pIDs := collectPolicyAffectedGroupsAndPeers(policy)
|
||||
groupIDs = append(groupIDs, gIDs...)
|
||||
directPeerIDs = append(directPeerIDs, pIDs...)
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return groupIDs, directPeerIDs
|
||||
}
|
||||
|
||||
// validatePostureChecks validates the posture checks.
|
||||
|
||||
@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
}
|
||||
|
||||
var newRoute *route.Route
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
newRoute = &route.Route{
|
||||
@@ -173,15 +173,13 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(newRoute)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -190,8 +188,8 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return newRoute, nil
|
||||
@@ -208,8 +206,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
var oldRoute *route.Route
|
||||
var oldRouteAffectsPeers bool
|
||||
var newRouteAffectsPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
|
||||
@@ -221,21 +218,15 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(routeToSave, oldRoute)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -244,8 +235,8 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -261,19 +252,17 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var route *route.Route
|
||||
var updateAccountPeers bool
|
||||
var rt *route.Route
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
groupIDs, directPeerIDs := collectRouteAffectedGroupsAndPeers(rt)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
|
||||
return err
|
||||
@@ -285,10 +274,10 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -377,23 +366,20 @@ func getPlaceholderIP() netip.Prefix {
|
||||
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||
}
|
||||
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
|
||||
if route.Peer != "" {
|
||||
return true, nil
|
||||
// collectRouteAffectedGroupsAndPeers returns group IDs and direct peer IDs from the given routes.
|
||||
func collectRouteAffectedGroupsAndPeers(routes ...*route.Route) (groupIDs []string, directPeerIDs []string) {
|
||||
for _, r := range routes {
|
||||
if r == nil {
|
||||
continue
|
||||
}
|
||||
groupIDs = append(groupIDs, r.Groups...)
|
||||
groupIDs = append(groupIDs, r.PeerGroups...)
|
||||
groupIDs = append(groupIDs, r.AccessControlGroups...)
|
||||
if r.Peer != "" {
|
||||
directPeerIDs = append(directPeerIDs, r.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
|
||||
return
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
|
||||
@@ -4662,6 +4662,23 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var peerIDs []string
|
||||
result := s.db.Model(&types.GroupPeer{}).
|
||||
Select("DISTINCT peer_id").
|
||||
Where("account_id = ? AND group_id IN ?", accountID, groupIDs).
|
||||
Pluck("peer_id", &peerIDs)
|
||||
if result.Error != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to get peer IDs by groups: %s", result.Error)
|
||||
}
|
||||
|
||||
return peerIDs, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
|
||||
@@ -159,6 +159,7 @@ type Store interface {
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
|
||||
GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
|
||||
@@ -178,6 +178,7 @@ func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr int
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1852,6 +1853,21 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
|
||||
}
|
||||
|
||||
// GetPeerIDsByGroups mocks base method.
|
||||
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
|
||||
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
|
||||
}
|
||||
|
||||
// GetPeersByIDs mocks base method.
|
||||
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Reference in New Issue
Block a user