diff --git a/management/server/account.go b/management/server/account.go index 6b928b8d7..b3228f83e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1794,9 +1794,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - usersMap := make(map[string]*User) - usersMap[claims.UserId] = NewRegularUser(claims.UserId) - err := am.Store.SaveUsers(domainAccountID, usersMap) + newUser := NewRegularUser(claims.UserId) + newUser.AccountID = domainAccountID + err := am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser) if err != nil { return "", err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a94f357e3..f10e2f8ff 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -393,20 +393,17 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr } // SaveUsers saves the given list of users to the database. -// It updates existing users if a conflict occurs. -func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { - usersToSave := make([]User, 0, len(users)) - for _, user := range users { - user.AccountID = accountID - for id, pat := range user.PATs { - pat.ID = id - user.PATsG = append(user.PATsG, *pat) - } - usersToSave = append(usersToSave, *user) +func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error { + if len(users) == 0 { + return nil } - return s.db.Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.OnConflict{UpdateAll: true}). - Create(&usersToSave).Error + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save users to store") + } + return nil } // SaveUser saves the given user to the database. diff --git a/management/server/store.go b/management/server/store.go index f9f01df4f..f7d5e9348 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -67,7 +67,7 @@ type Store interface { GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) - SaveUsers(accountID string, users map[string]*User) error + SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error diff --git a/management/server/user.go b/management/server/user.go index 9f992da7a..7e69fbaf0 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "slices" "strings" "time" @@ -625,10 +624,6 @@ func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initia // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { - if update == nil { - return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") - } - updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) if err != nil { return nil, err @@ -642,127 +637,111 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i } // SaveOrAddUsers updates existing users or adds new users to the account. -// Note: This function does not acquire the global lock. -// It is the caller's responsibility to ensure proper locking is in place before invoking this method. func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) { if len(updates) == 0 { return nil, nil //nolint:nilnil } - account, err := am.Store.GetAccount(ctx, accountID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - initiatorUser, err := account.FindUser(initiatorUserID) - if err != nil { - return nil, err + if initiatorUser.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") + return nil, status.NewAdminPermissionError() } - updatedUsers := make([]*UserInfo, 0, len(updates)) - var ( - expiredPeers []*nbpeer.Peer - userIDs []string - eventsToStore []func() - ) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } - for _, update := range updates { - if update == nil { - return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") - } + var updateAccountPeers bool + var peersToExpire []*nbpeer.Peer + var addUserEvents []func() + var usersToSave = make([]*User, 0, len(updates)) + var updatedUsersInfo = make([]*UserInfo, 0, len(updates)) - userIDs = append(userIDs, update.Id) + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account groups: %w", err) + } - oldUser := account.Users[update.Id] - if oldUser == nil { - if !addIfNotExists { - return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, update := range updates { + if update == nil { + return status.Errorf(status.InvalidArgument, "provided user update is nil") } - // when addIfNotExists is set to true, the newUser will use all fields from the update input - oldUser = update - } - if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil { - return nil, err - } - - // only auto groups, revoked status, and integration reference can be updated for now - newUser := oldUser.Copy() - newUser.Role = update.Role - newUser.Blocked = update.Blocked - newUser.AutoGroups = update.AutoGroups - // these two fields can't be set via API, only via direct call to the method - newUser.Issued = update.Issued - newUser.IntegrationReference = update.IntegrationReference - - transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update) - account.Users[newUser.Id] = newUser - - if !oldUser.IsBlocked() && update.IsBlocked() { - // expire peers that belong to the user who's getting blocked - blockedPeers, err := account.FindUserPeers(update.Id) + userHadPeers, updatedUser, userPeersToExpire, userEvents, err := processUserUpdate( + ctx, am, transaction, groupsMap, initiatorUser, update, addIfNotExists, settings, + ) if err != nil { - return nil, err + return fmt.Errorf("failed to process user update: %w", err) } - expiredPeers = append(expiredPeers, blockedPeers...) + usersToSave = append(usersToSave, updatedUser) + addUserEvents = append(addUserEvents, userEvents...) + peersToExpire = append(peersToExpire, userPeersToExpire...) + + if userHadPeers { + updateAccountPeers = true + } + + updatedUserInfo, err := getUserInfo(ctx, am, updatedUser, accountID) + if err != nil { + return fmt.Errorf("failed to get user info: %w", err) + } + updatedUsersInfo = append(updatedUsersInfo, updatedUserInfo) } - if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) - // need force update all auto groups in any case they will not be duplicated - account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) - } - - events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) - eventsToStore = append(eventsToStore, events...) - - updatedUserInfo, err := getUserInfo(ctx, am, newUser, accountID) - if err != nil { - return nil, err - } - updatedUsers = append(updatedUsers, updatedUserInfo) + return transaction.SaveUsers(ctx, LockingStrengthUpdate, usersToSave) + }) + if err != nil { + return nil, err } - if len(expiredPeers) > 0 { - if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { + for _, addUserEvent := range addUserEvents { + addUserEvent() + } + + if len(peersToExpire) > 0 { + if err := am.expireAndUpdatePeers(ctx, accountID, peersToExpire); err != nil { log.WithContext(ctx).Errorf("failed update expired peers: %s", err) return nil, err } } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { + if settings.GroupsPropagationEnabled && updateAccountPeers { + if err = am.Store.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return nil, fmt.Errorf("failed to increment network serial: %w", err) + } am.updateAccountPeers(ctx, accountID) } - for _, storeEvent := range eventsToStore { - storeEvent() - } - - return updatedUsers, nil + return updatedUsersInfo, nil } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, groupsMap map[string]*nbgroup.Group, accountID string, initiatorUserID string, oldUser, newUser *User, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { if newUser.IsBlocked() { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil) }) } else { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil) }) } } @@ -770,11 +749,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in switch { case transferredOwnerRole: eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil) }) case oldUser.Role != newUser.Role: eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role}) }) } @@ -782,23 +761,22 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupsMap[g] + if ok { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName} + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) }) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID) } } for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { + group, ok := groupsMap[g] + if ok { eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + meta := map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName} + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta) }) } } @@ -807,14 +785,92 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in return eventsToStore } -func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool { +func processUserUpdate(ctx context.Context, am *DefaultAccountManager, transaction Store, groupsMap map[string]*nbgroup.Group, + initiatorUser, update *User, addIfNotExists bool, settings *Settings) (bool, *User, []*nbpeer.Peer, []func(), error) { + + if update == nil { + return false, nil, nil, nil, status.Errorf(status.InvalidArgument, "provided user update is nil") + } + + oldUser, err := getUserOrCreateIfNotExists(ctx, transaction, update, addIfNotExists) + if err != nil { + return false, nil, nil, nil, err + } + + if err := validateUserUpdate(groupsMap, initiatorUser, oldUser, update); err != nil { + return false, nil, nil, nil, err + } + + // only auto groups, revoked status, and integration reference can be updated for now + updatedUser := oldUser.Copy() + updatedUser.AccountID = initiatorUser.AccountID + updatedUser.Role = update.Role + updatedUser.Blocked = update.Blocked + updatedUser.AutoGroups = update.AutoGroups + // these two fields can't be set via API, only via direct call to the method + updatedUser.Issued = update.Issued + updatedUser.IntegrationReference = update.IntegrationReference + + transferredOwnerRole, err := handleOwnerRoleTransfer(ctx, transaction, initiatorUser, update) + if err != nil { + return false, nil, nil, nil, err + } + + userPeers, err := transaction.GetUserPeers(ctx, LockingStrengthUpdate, updatedUser.AccountID, update.Id) + if err != nil { + return false, nil, nil, nil, err + } + + var peersToExpire []*nbpeer.Peer + + if !oldUser.IsBlocked() && update.IsBlocked() { + peersToExpire = userPeers + } + + if update.AutoGroups != nil && settings.GroupsPropagationEnabled { + removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + updatedGroups, err := am.updateUserPeersInGroups(groupsMap, userPeers, update.AutoGroups, removedGroups) + if err != nil { + return false, nil, nil, nil, fmt.Errorf("error modifying user peers in groups: %w", err) + } + + if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + return false, nil, nil, nil, fmt.Errorf("error saving groups: %w", err) + } + } + + updateAccountPeers := len(userPeers) > 0 + userEventsToAdd := am.prepareUserUpdateEvents(ctx, groupsMap, updatedUser.AccountID, initiatorUser.Id, oldUser, updatedUser, transferredOwnerRole) + + return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil +} + +// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. +func getUserOrCreateIfNotExists(ctx context.Context, transaction Store, update *User, addIfNotExists bool) (*User, error) { + existingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, update.Id) + if err != nil { + if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { + if !addIfNotExists { + return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id) + } + return update, nil // use all fields from update if addIfNotExists is true + } + return nil, err + } + return existingUser, nil +} + +func handleOwnerRoleTransfer(ctx context.Context, transaction Store, initiatorUser, update *User) (bool, error) { if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { newInitiatorUser := initiatorUser.Copy() newInitiatorUser.Role = UserRoleAdmin - account.Users[initiatorUser.Id] = newInitiatorUser - return true + + if err := transaction.SaveUser(ctx, LockingStrengthUpdate, newInitiatorUser); err != nil { + return false, err + } + return true, nil } - return false + return false, nil } // getUserInfo retrieves the UserInfo for a given User and Account. @@ -837,7 +893,7 @@ func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, acc } // validateUserUpdate validates the update operation for a user. -func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error { +func validateUserUpdate(groupsMap map[string]*nbgroup.Group, initiatorUser, oldUser, update *User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } @@ -858,7 +914,7 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) } for _, newGroupID := range update.AutoGroups { - group, ok := account.Groups[newGroupID] + group, ok := groupsMap[newGroupID] if !ok { return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist", newGroupID, update.Id) @@ -1284,16 +1340,6 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa return nil, false } -// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. -func areUsersLinkedToPeers(account *Account, userIDs []string) bool { - for _, peer := range account.Peers { - if slices.Contains(userIDs, peer.UserID) { - return true - } - } - return false -} - func validateUserInvite(invite *UserInfo) error { if invite == nil { return fmt.Errorf("provided user update is nil")