move updateUserPeersInGroups to account manager

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-03 20:13:55 +03:00
parent 18614d0b52
commit 19177feb80
5 changed files with 92 additions and 18 deletions

View File

@@ -1851,6 +1851,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
}()
if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
if err != nil {
@@ -1860,10 +1864,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
return fmt.Errorf("error saving user: %w", err)
@@ -1871,12 +1871,13 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
// Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled {
if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil {
return fmt.Errorf("error adding user peers to groups: %w", err)
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups)
if err != nil {
return fmt.Errorf("error modifying user peers in groups: %w", err)
}
if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil {
return fmt.Errorf("error removing user peers from groups: %w", err)
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {

View File

@@ -1051,10 +1051,6 @@ func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.G
return status.Errorf(status.Internal, "SaveGroup is not implemented")
}
func (s *FileStore) AddUserPeersToGroups(_ context.Context, _ string, _ string, _ []string) error {
return status.Errorf(status.Internal, "AddUserPeersToGroups is not implemented")
}
func (s *FileStore) RemoveUserPeersFromGroups(_ context.Context, _ string, _ string, _ []string) error {
return status.Errorf(status.Internal, "RemoveUserPeersFromGroups is not implemented")
func (s *FileStore) GetUserPeers(_ context.Context, _ LockingStrength, _, _ string) ([]*nbpeer.Peer, error) {
return nil, status.Errorf(status.Internal, "GetUserPeers is not implemented")
}

View File

@@ -1089,6 +1089,11 @@ func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID stri
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account")

View File

@@ -81,10 +81,9 @@ type Store interface {
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error

View File

@@ -8,14 +8,14 @@ import (
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
log "github.com/sirupsen/logrus"
)
const (
@@ -1254,6 +1254,79 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
}
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd,
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
return
}
peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil {
return nil, err
}
userPeerIDMap := make(map[string]struct{}, len(peers))
for _, peer := range peers {
userPeerIDMap[peer.ID] = struct{}{}
}
for _, gid := range groupsToAdd {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return nil, err
}
addUserPeersToGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
for _, gid := range groupsToRemove {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
if err != nil {
return nil, err
}
removeUserPeersFromGroup(userPeerIDMap, group)
groupsToUpdate = append(groupsToUpdate, group)
}
return groupsToUpdate, nil
}
// addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
}
for pid := range userPeerIDs {
groupPeers[pid] = struct{}{}
}
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
// removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
// skip removing peers from group All
if group.Name == "All" {
return
}
updatedPeers := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found {
updatedPeers = append(updatedPeers, pid)
}
}
group.Peers = updatedPeers
}
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
for _, user := range userData {
if user.ID == userID {