mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 14:16:38 +00:00
calculate affected peers
This commit is contained in:
@@ -79,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -91,11 +91,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
|
||||
return status.Errorf(status.Internal, "failed to create group: %v", err)
|
||||
}
|
||||
@@ -106,6 +101,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -116,8 +114,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -134,7 +132,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
|
||||
@@ -165,15 +163,13 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
@@ -184,8 +180,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -205,7 +201,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
@@ -243,17 +238,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
|
||||
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -273,7 +265,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var updateAccountPeers bool
|
||||
|
||||
var globalErr error
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
@@ -311,17 +302,14 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
|
||||
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return globalErr
|
||||
@@ -473,27 +461,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
|
||||
// GroupAddPeer appends peer to the group
|
||||
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -502,7 +488,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
// GroupAddResource appends resource to the group
|
||||
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -515,23 +501,21 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -539,14 +523,13 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
|
||||
// GroupDeletePeer removes peer from the group
|
||||
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Resolve before removing, so the peer being removed is still included
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
|
||||
return err
|
||||
@@ -558,8 +541,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -568,7 +551,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
// GroupDeleteResource removes resource from the group
|
||||
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
|
||||
var group *types.Group
|
||||
var updateAccountPeers bool
|
||||
var affectedPeerIDs []string
|
||||
var err error
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
@@ -581,23 +564,21 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
return nil
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.UpdateGroup(ctx, group); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
|
||||
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -840,18 +821,175 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
// collectGroupChangeAffectedGroups walks all entities that reference the changed groups
|
||||
// and collects the full set of affected group IDs and direct peer IDs.
|
||||
// This ensures that when a group changes, we update not just the peers in that group
|
||||
// but also peers in other groups that share policies, routes, DNS, or nameserver configs.
|
||||
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
|
||||
if len(changedGroupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
if group.HasPeers() || group.HasResources() {
|
||||
return true, nil
|
||||
changedSet := make(map[string]struct{}, len(changedGroupIDs))
|
||||
for _, id := range changedGroupIDs {
|
||||
changedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
// Always include the changed groups themselves
|
||||
for _, id := range changedGroupIDs {
|
||||
groupSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
peerSet := make(map[string]struct{})
|
||||
|
||||
// Policies: collect all rule groups + direct peer resources from policies that reference any changed group
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, policy := range policies {
|
||||
if !policyReferencesGroups(policy, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range policy.RuleGroups() {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
peerSet[rule.SourceResource.ID] = struct{}{}
|
||||
}
|
||||
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||
peerSet[rule.DestinationResource.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
// Routes: collect all groups + direct peer from routes that reference any changed group
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, r := range routes {
|
||||
if !routeReferencesGroups(r, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range r.Groups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
for _, gID := range r.PeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
for _, gID := range r.AccessControlGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
if r.Peer != "" {
|
||||
peerSet[r.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Nameserver groups: collect groups from NS groups that reference any changed group
|
||||
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, ns := range nsGroups {
|
||||
for _, gID := range ns.Groups {
|
||||
if _, ok := changedSet[gID]; ok {
|
||||
for _, g := range ns.Groups {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DNS settings: if any changed group is in DisabledManagementGroups, include those groups
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, gID := range dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := changedSet[gID]; ok {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Network routers: collect peer groups + direct peer from routers that reference any changed group
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
|
||||
} else {
|
||||
for _, router := range routers {
|
||||
if !routerReferencesGroups(router, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range router.PeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
if router.Peer != "" {
|
||||
peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
allGroupIDs = make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
allGroupIDs = append(allGroupIDs, gID)
|
||||
}
|
||||
|
||||
directPeerIDs = make([]string, 0, len(peerSet))
|
||||
for pID := range peerSet {
|
||||
directPeerIDs = append(directPeerIDs, pID)
|
||||
}
|
||||
|
||||
return allGroupIDs, directPeerIDs
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, gID := range rule.Sources {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
|
||||
for _, gID := range r.Groups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, gID := range r.PeerGroups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
for _, gID := range r.AccessControlGroups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
|
||||
for _, gID := range router.PeerGroups {
|
||||
if _, ok := groupSet[gID]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user