diff --git a/management/server/account.go b/management/server/account.go index 2902bc952..043b797ab 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2126,12 +2126,12 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - removedGroupAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, removeOldGroups) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { return err } - newGroupsAffectsPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, addNewGroups) + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) if err != nil { return err } diff --git a/management/server/group.go b/management/server/group.go index da4c0fb94..c49bb2471 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -79,7 +79,7 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -89,66 +89,35 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return status.NewUserNotPartOfAccountError() } - var ( - eventsToStore []func() - groupsToSave []*nbgroup.Group - ) - - for _, newGroup := range newGroups { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } - - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err - } - } - - // Avoid duplicate groups only for the API issued groups. - // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) - } - - newGroup.ID = xid.New().String() - } - - for _, peerID := range newGroup.Peers { - if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } - } - - newGroup.AccountID = accountID - groupsToSave = append(groupsToSave, newGroup) - - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup) - eventsToStore = append(eventsToStore, events...) - } - - newGroupIDs := make([]string, 0, len(newGroups)) - for _, newGroup := range newGroups { - newGroupIDs = append(newGroupIDs, newGroup.ID) - } - - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, newGroupIDs) - if err != nil { - return err - } + var eventsToStore []func() + var groupsToSave []*nbgroup.Group + var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err + } + + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err + } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave); err != nil { - return fmt.Errorf("failed to save groups: %w", err) - } - return nil + return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -166,13 +135,13 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) @@ -184,36 +153,34 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID } for _, peerID := range addedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) if err != nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: %v", peerID, err) continue } - peerCopy := peer // copy to avoid closure issues + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } for _, peerID := range removedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + peer, err := transaction.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) if err != nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: %v", peerID, err) continue } - peerCopy := peer // copy to avoid closure issues + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -246,28 +213,27 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.NewUserNotPartOfAccountError() } - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - if group.Name == "All" { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = am.validateDeleteGroup(ctx, group, userID); err != nil { - return err - } + var group *nbgroup.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + return err + } + + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + + if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { + return err + } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID); err != nil { - return fmt.Errorf("failed to delete group: %w", err) - } - return nil + return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) }) if err != nil { return err @@ -279,6 +245,11 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use } // DeleteGroups deletes groups from an account. +// Note: This function does not acquire the global lock. +// It is the caller's responsibility to ensure proper locking is in place before invoking this method. +// +// If an error occurs while deleting a group, the function skips it and continues deleting other groups. +// Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { @@ -289,36 +260,31 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return status.NewUserNotPartOfAccountError() } - var ( - allErrors error - groupIDsToDelete []string - deletedGroups []*nbgroup.Group - ) - - for _, groupID := range groupIDs { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - continue - } - - if err := am.validateDeleteGroup(ctx, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) - continue - } - - groupIDsToDelete = append(groupIDsToDelete, groupID) - deletedGroups = append(deletedGroups, group) - } + var allErrors error + var groupIDsToDelete []string + var deletedGroups []*nbgroup.Group err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + if err != nil { + continue + } + + if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + continue + } + + groupIDsToDelete = append(groupIDsToDelete, groupID) + deletedGroups = append(deletedGroups, group) + } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { - return fmt.Errorf("failed to delete group: %w", err) - } - return nil + return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -333,36 +299,30 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - add := true - for _, itemID := range group.Peers { - if itemID == peerID { - add = false - break - } - } - if add { - group.Peers = append(group.Peers, peerID) - } - - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddPeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { - return fmt.Errorf("failed to save group: %w", err) - } - return nil + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) }) if err != nil { return err @@ -377,38 +337,30 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - updated := false - for i, itemID := range group.Peers { - if itemID == peerID { - group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - updated = true - break - } - } - - if !updated { - return nil - } - - updateAccountPeers, err := am.areGroupChangesAffectPeers(ctx, accountID, []string{groupID}) - if err != nil { - return err - } + var group *nbgroup.Group + var updateAccountPeers bool + var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemovePeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return err } - if err = transaction.SaveGroup(ctx, LockingStrengthUpdate, group); err != nil { - return fmt.Errorf("failed to save group: %w", err) - } - return nil + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) }) if err != nil { return err @@ -421,10 +373,43 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return nil } -func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group *nbgroup.Group, userID string) error { +// validateNewGroup validates the new group for existence and required fields. +func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if err != nil { + if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { + return err + } + } + + // Prevent duplicate groups for API-issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } + } + + return nil +} + +func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return status.Errorf(status.NotFound, "user not found") } @@ -433,27 +418,27 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } } - if isLinked, linkedRoute := am.isGroupLinkedToRoute(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := am.isGroupLinkedToDns(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := am.isGroupLinkedToPolicy(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := am.isGroupLinkedToSetupKey(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := am.isGroupLinkedToUser(ctx, group.AccountID, group.ID); isLinked { + if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -462,7 +447,7 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -477,8 +462,8 @@ func (am *DefaultAccountManager) validateDeleteGroup(ctx context.Context, group } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accountID string, groupID string) (bool, *route.Route) { - routes, err := am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -494,8 +479,8 @@ func (am *DefaultAccountManager) isGroupLinkedToRoute(ctx context.Context, accou } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, accountID string, groupID string) (bool, *Policy) { - policies, err := am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -512,8 +497,8 @@ func (am *DefaultAccountManager) isGroupLinkedToPolicy(ctx context.Context, acco } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -531,8 +516,8 @@ func (am *DefaultAccountManager) isGroupLinkedToDns(ctx context.Context, account } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -547,8 +532,8 @@ func (am *DefaultAccountManager) isGroupLinkedToSetupKey(ctx context.Context, ac } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accountID string, groupID string) (bool, *User) { - users, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { + users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -563,12 +548,12 @@ func (am *DefaultAccountManager) isGroupLinkedToUser(ctx context.Context, accoun } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) if err != nil { return false, err } @@ -577,13 +562,13 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { return true, nil } - if linked, _ := am.isGroupLinkedToDns(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { return true, nil } - if linked, _ := am.isGroupLinkedToPolicy(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { return true, nil } - if linked, _ := am.isGroupLinkedToRoute(ctx, accountID, groupID); linked { + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { return true, nil } } @@ -591,40 +576,6 @@ func (am *DefaultAccountManager) areGroupChangesAffectPeers(ctx context.Context, return false, nil } -// isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { - for _, r := range routes { - if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { - return true, r - } - } - return false, nil -} - -// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { - for _, policy := range policies { - for _, rule := range policy.Rules { - if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { - return true, policy - } - } - } - return false, nil -} - -// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { - for _, dns := range nameServerGroups { - for _, g := range dns.Groups { - if g == groupID { - return true, dns - } - } - } - return false, nil -} - // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { @@ -634,22 +585,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } - -func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { - return true - } - if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { - return true - } - } - - return false -} diff --git a/management/server/peer.go b/management/server/peer.go index 994cc0287..33f27d8c7 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -331,7 +331,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) + if err != nil { + return err + } err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { @@ -594,9 +597,14 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) } - groupsToAdd = append(groupsToAdd, allGroup.ID) - if areGroupChangesAffectPeers(account, groupsToAdd) { + + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + + if newGroupsAffectsPeers { am.updateAccountPeers(ctx, accountID) } @@ -1033,12 +1041,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { peerGroupIDs = append(peerGroupIDs, group.ID) } } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) }