calculate affected peers

This commit is contained in:
pascal
2026-04-27 17:49:12 +02:00
parent 154b81645a
commit 285bbc5ffb
16 changed files with 540 additions and 259 deletions

View File

@@ -79,7 +79,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -91,11 +91,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -106,6 +101,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -116,8 +114,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -134,7 +132,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -165,15 +163,13 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, newGroup); err != nil {
return err
}
groupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{newGroup.ID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, groupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -184,8 +180,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -205,7 +201,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -243,17 +238,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return globalErr
@@ -273,7 +265,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -311,17 +302,14 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, am.Store, accountID, groupIDs)
affectedPeerIDs := am.resolvePeerIDs(ctx, am.Store, accountID, allGroupIDs, directPeerIDs)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return globalErr
@@ -473,27 +461,25 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -502,7 +488,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -515,23 +501,21 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -539,14 +523,13 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
// Resolve before removing, so the peer being removed is still included
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return err
@@ -558,8 +541,8 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -568,7 +551,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -581,23 +564,21 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, transaction, accountID, []string{groupID})
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroupIDs, directPeerIDs)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
if len(affectedPeerIDs) > 0 {
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
return nil
@@ -840,18 +821,175 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil {
return false, err
// collectGroupChangeAffectedGroups walks all entities that reference the changed groups
// and collects the full set of affected group IDs and direct peer IDs.
// This ensures that when a group changes, we update not just the peers in that group
// but also peers in other groups that share policies, routes, DNS, or nameserver configs.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 {
return nil, nil
}
for _, group := range groups {
if group.HasPeers() || group.HasResources() {
return true, nil
changedSet := make(map[string]struct{}, len(changedGroupIDs))
for _, id := range changedGroupIDs {
changedSet[id] = struct{}{}
}
groupSet := make(map[string]struct{})
// Always include the changed groups themselves
for _, id := range changedGroupIDs {
groupSet[id] = struct{}{}
}
peerSet := make(map[string]struct{})
// Policies: collect all rule groups + direct peer resources from policies that reference any changed group
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
} else {
for _, policy := range policies {
if !policyReferencesGroups(policy, changedSet) {
continue
}
for _, gID := range policy.RuleGroups() {
groupSet[gID] = struct{}{}
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
}
return false, nil
// Routes: collect all groups + direct peer from routes that reference any changed group
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
} else {
for _, r := range routes {
if !routeReferencesGroups(r, changedSet) {
continue
}
for _, gID := range r.Groups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.PeerGroups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.AccessControlGroups {
groupSet[gID] = struct{}{}
}
if r.Peer != "" {
peerSet[r.Peer] = struct{}{}
}
}
}
// Nameserver groups: collect groups from NS groups that reference any changed group
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
} else {
for _, ns := range nsGroups {
for _, gID := range ns.Groups {
if _, ok := changedSet[gID]; ok {
for _, g := range ns.Groups {
groupSet[g] = struct{}{}
}
break
}
}
}
}
// DNS settings: if any changed group is in DisabledManagementGroups, include those groups
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
} else {
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedSet[gID]; ok {
groupSet[gID] = struct{}{}
}
}
}
// Network routers: collect peer groups + direct peer from routers that reference any changed group
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
} else {
for _, router := range routers {
if !routerReferencesGroups(router, changedSet) {
continue
}
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
if router.Peer != "" {
peerSet[router.Peer] = struct{}{}
}
}
}
allGroupIDs = make([]string, 0, len(groupSet))
for gID := range groupSet {
allGroupIDs = append(allGroupIDs, gID)
}
directPeerIDs = make([]string, 0, len(peerSet))
for pID := range peerSet {
directPeerIDs = append(directPeerIDs, pID)
}
return allGroupIDs, directPeerIDs
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
for _, gID := range rule.Sources {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range rule.Destinations {
if _, ok := groupSet[gID]; ok {
return true
}
}
}
return false
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
for _, gID := range r.Groups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.AccessControlGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
for _, gID := range router.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}