mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
move updateUserPeersInGroups to account manager
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@@ -1851,6 +1851,10 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
}()
|
||||
|
||||
if err = am.Store.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, claims.UserId)
|
||||
if err != nil {
|
||||
@@ -1860,10 +1864,6 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
addNewGroups := difference(updatedAutoGroups, user.AutoGroups)
|
||||
removeOldGroups := difference(user.AutoGroups, updatedAutoGroups)
|
||||
|
||||
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
user.AutoGroups = updatedAutoGroups
|
||||
if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil {
|
||||
return fmt.Errorf("error saving user: %w", err)
|
||||
@@ -1871,12 +1871,13 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
|
||||
// Propagate changes to peers if group propagation is enabled
|
||||
if settings.GroupsPropagationEnabled {
|
||||
if err = transaction.AddUserPeersToGroups(ctx, accountID, claims.UserId, addNewGroups); err != nil {
|
||||
return fmt.Errorf("error adding user peers to groups: %w", err)
|
||||
updatedGroups, err := am.updateUserPeersInGroups(ctx, accountID, claims.UserId, addNewGroups, removeOldGroups)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error modifying user peers in groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.RemoveUserPeersFromGroups(ctx, accountID, claims.UserId, removeOldGroups); err != nil {
|
||||
return fmt.Errorf("error removing user peers from groups: %w", err)
|
||||
if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil {
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
|
||||
@@ -1051,10 +1051,6 @@ func (s *FileStore) SaveGroup(_ context.Context, _ LockingStrength, _ *nbgroup.G
|
||||
return status.Errorf(status.Internal, "SaveGroup is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) AddUserPeersToGroups(_ context.Context, _ string, _ string, _ []string) error {
|
||||
return status.Errorf(status.Internal, "AddUserPeersToGroups is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) RemoveUserPeersFromGroups(_ context.Context, _ string, _ string, _ []string) error {
|
||||
return status.Errorf(status.Internal, "RemoveUserPeersFromGroups is not implemented")
|
||||
func (s *FileStore) GetUserPeers(_ context.Context, _ LockingStrength, _, _ string) ([]*nbpeer.Peer, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetUserPeers is not implemented")
|
||||
}
|
||||
|
||||
@@ -1089,6 +1089,11 @@ func (s *SqlStore) RemoveUserPeersFromGroups(ctx context.Context, accountID stri
|
||||
return s.SaveGroups(ctx, LockingStrengthUpdate, groupsToUpdate)
|
||||
}
|
||||
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account")
|
||||
|
||||
@@ -81,10 +81,9 @@ type Store interface {
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddUserPeersToGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
|
||||
RemoveUserPeersFromGroups(ctx context.Context, accountID string, userID string, groupIDs []string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
|
||||
@@ -8,14 +8,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -1254,6 +1254,79 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
|
||||
}
|
||||
|
||||
// updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them.
|
||||
func (am *DefaultAccountManager) updateUserPeersInGroups(ctx context.Context, accountID, userID string, groupsToAdd,
|
||||
groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) {
|
||||
|
||||
if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userPeerIDMap := make(map[string]struct{}, len(peers))
|
||||
for _, peer := range peers {
|
||||
userPeerIDMap[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, gid := range groupsToAdd {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addUserPeersToGroup(userPeerIDMap, group)
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
|
||||
for _, gid := range groupsToRemove {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, gid, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
removeUserPeersFromGroup(userPeerIDMap, group)
|
||||
groupsToUpdate = append(groupsToUpdate, group)
|
||||
}
|
||||
|
||||
return groupsToUpdate, nil
|
||||
}
|
||||
|
||||
// addUserPeersToGroup adds the user's peers to the group.
|
||||
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
|
||||
groupPeers := make(map[string]struct{}, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
groupPeers[pid] = struct{}{}
|
||||
}
|
||||
|
||||
for pid := range userPeerIDs {
|
||||
groupPeers[pid] = struct{}{}
|
||||
}
|
||||
|
||||
group.Peers = make([]string, 0, len(groupPeers))
|
||||
for pid := range groupPeers {
|
||||
group.Peers = append(group.Peers, pid)
|
||||
}
|
||||
}
|
||||
|
||||
// removeUserPeersFromGroup removes user's peers from the group.
|
||||
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) {
|
||||
// skip removing peers from group All
|
||||
if group.Name == "All" {
|
||||
return
|
||||
}
|
||||
|
||||
updatedPeers := make([]string, 0, len(group.Peers))
|
||||
for _, pid := range group.Peers {
|
||||
if _, found := userPeerIDs[pid]; !found {
|
||||
updatedPeers = append(updatedPeers, pid)
|
||||
}
|
||||
}
|
||||
|
||||
group.Peers = updatedPeers
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
|
||||
Reference in New Issue
Block a user