mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-05 08:36:37 +00:00
calculate affected peers
This commit is contained in:
@@ -40,9 +40,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var updateAccountPeers bool
|
||||
var isUpdate = postureChecks.ID != ""
|
||||
var action = activity.PostureCheckCreated
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
|
||||
@@ -50,12 +50,10 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if isUpdate {
|
||||
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
action = activity.PostureCheckUpdated
|
||||
|
||||
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(ctx, transaction, accountID, postureChecks.ID)
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
}
|
||||
|
||||
postureChecks.AccountID = accountID
|
||||
@@ -75,8 +73,8 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
@@ -132,27 +130,23 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
|
||||
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
|
||||
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
|
||||
// collectPostureCheckAffectedGroupsAndPeers finds all policies referencing the given posture check
|
||||
// and collects their affected group IDs and direct peer IDs.
|
||||
func collectPostureCheckAffectedGroupsAndPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (groupIDs []string, directPeerIDs []string) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
gIDs, pIDs := collectPolicyAffectedGroupsAndPeers(policy)
|
||||
groupIDs = append(groupIDs, gIDs...)
|
||||
directPeerIDs = append(directPeerIDs, pIDs...)
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
return groupIDs, directPeerIDs
|
||||
}
|
||||
|
||||
// validatePostureChecks validates the posture checks.
|
||||
|
||||
Reference in New Issue
Block a user