diff --git a/management/server/account.go b/management/server/account.go index 183867e3b..107a48438 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1851,6 +1851,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } }() + if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + return fmt.Errorf("error saving groups: %w", err) + } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId) if err != nil { @@ -1860,10 +1864,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st addNewGroups := difference(updatedAutoGroups, user.AutoGroups) removeOldGroups := difference(user.AutoGroups, updatedAutoGroups) - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { - return fmt.Errorf("error saving groups: %w", err) - } - user.AutoGroups = updatedAutoGroups if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) @@ -1871,12 +1871,13 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil { - return fmt.Errorf("error adding user peers to groups: %w", err) + updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups) + if err != nil { + return fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil { - return fmt.Errorf("error removing user peers from groups: %w", err) + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return fmt.Errorf("error saving groups: %w", err) } if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { diff --git a/management/server/file_store.go b/management/server/file_store.go index b3375ee11..19cb0be32 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1051,10 +1051,6 @@ func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.G return status.Errorf(status.Internal, "SaveGroup is not implemented") } -func (s *FileStore) AddUserPeersToGroups(_ context.Context, _ string, _ string, _ []string) error { - return status.Errorf(status.Internal, "AddUserPeersToGroups is not implemented") -} - -func (s *FileStore) RemoveUserPeersFromGroups(_ context.Context, _ string, _ string, _ []string) error { - return status.Errorf(status.Internal, "RemoveUserPeersFromGroups is not implemented") +func (s *FileStore) GetUserPeers(_ context.Context, _ LockingStrength, _, _ string) ([]*nbpeer.Peer, error) { + return nil, status.Errorf(status.Internal, "GetUserPeers is not implemented") } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 13f5c5e9e..baf76e54a 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1089,6 +1089,11 @@ func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID stri return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate) } +// GetUserPeers retrieves peers for a user. +func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { + return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account") diff --git a/management/server/store.go b/management/server/store.go index 9bc6eafce..7434d672e 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -81,10 +81,9 @@ type Store interface { GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error - RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error diff --git a/management/server/user.go b/management/server/user.go index 6d01561c6..8c3ad846d 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -8,14 +8,14 @@ import ( "time" "github.com/google/uuid" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server/activity" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + log "github.com/sirupsen/logrus" ) const ( @@ -1254,6 +1254,79 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil } +// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. +func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { + + if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { + return + } + + peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID) + if err != nil { + return nil, err + } + + userPeerIDMap := make(map[string]struct{}, len(peers)) + for _, peer := range peers { + userPeerIDMap[peer.ID] = struct{}{} + } + + for _, gid := range groupsToAdd { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return nil, err + } + addUserPeersToGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + for _, gid := range groupsToRemove { + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID) + if err != nil { + return nil, err + } + removeUserPeersFromGroup(userPeerIDMap, group) + groupsToUpdate = append(groupsToUpdate, group) + } + + return groupsToUpdate, nil +} + +// addUserPeersToGroup adds the user's peers to the group. +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + groupPeers := make(map[string]struct{}, len(group.Peers)) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeerIDs { + groupPeers[pid] = struct{}{} + } + + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } +} + +// removeUserPeersFromGroup removes user's peers from the group. +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { + // skip removing peers from group All + if group.Name == "All" { + return + } + + updatedPeers := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + if _, found := userPeerIDs[pid]; !found { + updatedPeers = append(updatedPeers, pid) + } + } + + group.Peers = updatedPeers +} + func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { for _, user := range userData { if user.ID == userID {