diff --git a/management/server/account.go b/management/server/account.go index 29415b038..6738235a0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1403,9 +1403,20 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups) + for _, group := range addNewGroups { + err = transaction.AddUserToGroup(ctx, userAuth.AccountId, group, user.Id) + if err != nil { + return fmt.Errorf("error adding user to group: %w", err) + } + } removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) + for _, group := range removeOldGroups { + err = transaction.RemoveUserFromGroup(ctx, user.Id, group) + if err != nil { + return fmt.Errorf("error removing user from group: %w", err) + } + } - user.AutoGroups = updatedAutoGroups if err = transaction.SaveUser(ctx, user); err != nil { return fmt.Errorf("error saving user: %w", err) } diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 78f4afbd5..7a9155eba 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -487,3 +487,103 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri log.WithContext(ctx).Infof("Migration of JSON field %s from table %s into separate table completed", columnName, tableName) return nil } + +// CleanupOrphanedIDs removes non-existent IDs from the JSON array column. +// T is the type of the model that contains the list. +// This migration cleans up the lists field by removing IDs that no longer exist in the target table. +func CleanupOrphanedIDs[T, S any](ctx context.Context, db *gorm.DB, columnName string) error { + var sourceModel T + var fkModel S + + if !db.Migrator().HasTable(&sourceModel) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", sourceModel) + return nil + } + + if !db.Migrator().HasTable(&fkModel) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", fkModel) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&sourceModel) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if !db.Migrator().HasColumn(&sourceModel, columnName) { + log.WithContext(ctx).Debugf("Column %s does not exist in table %s, no migration needed", columnName, tableName) + return nil + } + + if err := db.Transaction(func(tx *gorm.DB) error { + var rows []map[string]any + if err := tx.Table(tableName).Select("id", columnName).Find(&rows).Error; err != nil { + return fmt.Errorf("find rows: %w", err) + } + + // Get all valid IDs from the fk table + var validIDs []string + if err := tx.Model(fkModel).Select("id").Pluck("id", &validIDs).Error; err != nil { + return fmt.Errorf("fetch valid group IDs: %w", err) + } + + validIDMap := make(map[string]bool, len(validIDs)) + for _, id := range validIDs { + validIDMap[id] = true + } + + updatedCount := 0 + for _, row := range rows { + jsonValue, ok := row[columnName].(string) + if !ok || jsonValue == "" || jsonValue == "null" { + continue + } + + var list []string + if err := json.Unmarshal([]byte(jsonValue), &list); err != nil { + log.WithContext(ctx).Warnf("Failed to unmarshal %s for id %v: %v", columnName, row["id"], err) + continue + } + + if len(list) == 0 { + continue + } + + // Filter out non-existent IDs + cleanedList := make([]string, 0, len(list)) + for _, groupID := range list { + if validIDMap[groupID] { + cleanedList = append(cleanedList, groupID) + } + } + + // Only update if there were orphaned ids removed + if len(cleanedList) != len(list) { + cleanedJSON, err := json.Marshal(cleanedList) + if err != nil { + return fmt.Errorf("marshal cleaned %s: %w", columnName, err) + } + + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, cleanedJSON).Error; err != nil { + return fmt.Errorf("update row with id %v: %w", row["id"], err) + } + updatedCount++ + } + } + + if updatedCount > 0 { + log.WithContext(ctx).Infof("Cleaned up orphaned %s in %d rows from table %s", columnName, updatedCount, tableName) + } else { + log.WithContext(ctx).Debugf("No orphaned %s found in table %s", columnName, tableName) + } + + return nil + }); err != nil { + return err + } + + log.WithContext(ctx).Infof("Cleanup of orphaned auto_groups from table %s completed", tableName) + return nil +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index f407a35e6..d9323e107 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -119,7 +119,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met return nil, fmt.Errorf("migratePreAuto: %w", err) } err = db.AutoMigrate( - &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, &types.GroupUser{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, @@ -185,6 +185,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro for _, group := range account.GroupsG { group.StoreGroupPeers() + group.StoreGroupUsers() } err := s.transaction(func(tx *gorm.DB) error { @@ -243,6 +244,7 @@ func generateAccountSQLTypes(account *types.Account) { pat.ID = id user.PATsG = append(user.PATsG, *pat) } + user.LoadAutoGroups() account.UsersG = append(account.UsersG, *user) } @@ -453,6 +455,7 @@ func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { userCopy := user.Copy() userCopy.Email = user.Email userCopy.Name = user.Name + userCopy.StoreAutoGroups() if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt user: %w", err) } @@ -472,6 +475,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error { userCopy := user.Copy() userCopy.Email = user.Email userCopy.Name = user.Name + userCopy.StoreAutoGroups() if err := userCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt user: %w", err) @@ -617,6 +621,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren var user types.User result := tx. + Preload("Groups"). Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). Where("personal_access_tokens.id = ?", patID).Take(&user) if result.Error != nil { @@ -631,6 +636,8 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren return nil, fmt.Errorf("decrypt user: %w", err) } + user.LoadAutoGroups() + return &user, nil } @@ -641,7 +648,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } var user types.User - result := tx.Take(&user, idQueryCondition, userID) + result := tx.Preload("Groups").Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -653,6 +660,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return nil, fmt.Errorf("decrypt user: %w", err) } + user.LoadAutoGroups() + return &user, nil } @@ -680,7 +689,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre } var users []*types.User - result := tx.Find(&users, accountIDCondition, accountID) + result := tx.Preload("Groups").Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -693,6 +702,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { return nil, fmt.Errorf("decrypt user: %w", err) } + user.LoadAutoGroups() } return users, nil @@ -705,7 +715,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre } var user types.User - result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) + result := tx.Preload("Groups").Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed") @@ -717,6 +727,8 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre return nil, fmt.Errorf("decrypt user: %w", err) } + user.LoadAutoGroups() + return &user, nil } @@ -738,6 +750,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr for _, g := range groups { g.LoadGroupPeers() + g.LoadGroupUsers() } return groups, nil @@ -767,6 +780,7 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt for _, g := range groups { g.LoadGroupPeers() + g.LoadGroupUsers() } return groups, nil @@ -867,6 +881,8 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types Preload("SetupKeysG"). Preload("PeersG"). Preload("UsersG"). + Preload("UsersG.GroupUser"). + Preload("GroupsG"). Preload("GroupsG.GroupPeers"). Preload("RoutesG"). Preload("NameServerGroupsG"). @@ -908,9 +924,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types pat.UserID = "" user.PATs[pat.ID] = &pat } - if user.AutoGroups == nil { - user.AutoGroups = []string{} + if user.Groups == nil { + user.Groups = []*types.GroupUser{} } + user.LoadAutoGroups() if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { return nil, fmt.Errorf("decrypt user: %w", err) } @@ -1116,8 +1133,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. groupIDs = append(groupIDs, g.ID) } - wg.Add(3) - errChan = make(chan error, 3) + wg.Add(4) + errChan = make(chan error, 4) var pats []types.PersonalAccessToken go func() { @@ -1149,6 +1166,16 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. } }() + var groupUsers []types.GroupUser + go func() { + defer wg.Done() + var err error + groupUsers, err = s.getGroupUsers(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + wg.Wait() close(errChan) for e := range errChan { @@ -1174,6 +1201,12 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) } + groupsByUserID := make(map[string][]*types.GroupUser) + for i := range groupUsers { + gu := &groupUsers[i] + groupsByUserID[gu.UserID] = append(groupsByUserID[gu.UserID], gu) + } + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for i := range account.SetupKeysG { key := &account.SetupKeysG[i] @@ -1199,6 +1232,8 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. user.PATs[pat.ID] = pat } } + user.Groups = groupsByUserID[user.Id] + user.LoadAutoGroups() account.Users[user.Id] = user } @@ -1596,43 +1631,40 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee } func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { - const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1` + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err } users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { var u types.User - var autoGroups []byte var lastLogin, createdAt sql.NullTime var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool - err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name) - if err == nil { - if lastLogin.Valid { - u.LastLogin = &lastLogin.Time - } - if createdAt.Valid { - u.CreatedAt = createdAt.Time - } - if isServiceUser.Valid { - u.IsServiceUser = isServiceUser.Bool - } - if nonDeletable.Valid { - u.NonDeletable = nonDeletable.Bool - } - if blocked.Valid { - u.Blocked = blocked.Bool - } - if pendingApproval.Valid { - u.PendingApproval = pendingApproval.Bool - } - if autoGroups != nil { - _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} - } + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name) + if err != nil { + return u, err } - return u, err + + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + u.CreatedAt = createdAt.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + + return u, nil }) if err != nil { return nil, err @@ -2038,6 +2070,22 @@ func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]type return groupPeers, nil } +func (s *SqlStore) getGroupUsers(ctx context.Context, userIDs []string) ([]types.GroupUser, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, user_id FROM group_users WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + groupUsers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupUser]) + if err != nil { + return nil, err + } + return groupUsers, nil +} + func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -2659,6 +2707,41 @@ func (s *SqlStore) RemovePeerFromGroup(ctx context.Context, peerID string, group return nil } +func (s *SqlStore) AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error { + user := &types.GroupUser{ + AccountID: accountID, + GroupID: groupID, + UserID: userID, + } + + err := s.db.Clauses(clause.OnConflict{ + Columns: []clause.Column{{Name: "group_id"}, {Name: "user_id"}}, + DoNothing: true, + }).Create(user).Error + + if err != nil { + log.WithContext(ctx).Errorf("failed to add user %s to group %s for account %s: %v", userID, groupID, accountID, err) + return status.Errorf(status.Internal, "failed to add user to group") + } + + return nil +} + +func (s *SqlStore) RemoveUserFromGroup(ctx context.Context, userID, groupID string) error { + result := s.db.Delete(&types.GroupUser{}, "group_id = ? AND user_id = ?", groupID, userID) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to remove user %s from group %s: %v", userID, groupID, result.Error) + return status.Errorf(status.Internal, "failed to remove user from group") + } + + if result.RowsAffected == 0 { + log.WithContext(ctx).Warnf("user %s was not in group %s", userID, groupID) + } + + return nil +} + // RemovePeerFromAllGroups removes a peer from all groups func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) error { err := s.db. @@ -2745,6 +2828,7 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng for _, group := range groups { group.LoadGroupPeers() + group.LoadGroupUsers() } return groups, nil @@ -3146,6 +3230,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } group.LoadGroupPeers() + group.LoadGroupUsers() return group, nil } @@ -3177,6 +3262,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } group.LoadGroupPeers() + group.LoadGroupUsers() return &group, nil } @@ -3198,6 +3284,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren groupsMap := make(map[string]*types.Group) for _, group := range groups { group.LoadGroupPeers() + group.LoadGroupUsers() groupsMap[group.ID] = group } diff --git a/management/server/store/store.go b/management/server/store/store.go index 013a66d73..372f2ebdc 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -89,6 +89,8 @@ type Store interface { GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error + AddUserToGroup(ctx context.Context, accountID, userID, groupID string) error + RemoveUserFromGroup(ctx context.Context, userID, groupID string) error GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) @@ -350,6 +352,9 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateNewField[types.User](ctx, db, "email", "") }, + func(db *gorm.DB) error { + return migration.CleanupOrphanedIDs[types.User, types.Group](ctx, db, "auto_groups") + }, } } // migratePostAuto migrates the SQLite database to the latest schema func migratePostAuto(ctx context.Context, db *gorm.DB) error { @@ -381,6 +386,15 @@ func getMigrationsPostAuto(ctx context.Context) []migrationFunc { } }) }, + func(db *gorm.DB) error { + return migration.MigrateJsonToTable[types.User](ctx, db, "auto_groups", func(accountID, id, value string) any { + return &types.GroupUser{ + AccountID: accountID, + GroupID: value, + UserID: id, + } + }) + }, } } diff --git a/management/server/types/group.go b/management/server/types/group.go index 00fdf7a69..5aacf5ae9 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -28,6 +28,8 @@ type Group struct { // Peers list of the group Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"` + Users []string `gorm:"-"` + GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"` // Resources contains a list of resources in that group Resources []Resource `gorm:"serializer:json"` @@ -41,6 +43,12 @@ type GroupPeer struct { PeerID string `gorm:"primaryKey"` } +type GroupUser struct { + AccountID string `gorm:"index"` + GroupID string `gorm:"primaryKey"` + UserID string `gorm:"primaryKey"` +} + func (g *Group) LoadGroupPeers() { g.Peers = make([]string, len(g.GroupPeers)) for i, peer := range g.GroupPeers { @@ -61,6 +69,26 @@ func (g *Group) StoreGroupPeers() { g.Peers = []string{} } +func (g *Group) LoadGroupUsers() { + g.Users = make([]string, len(g.GroupUsers)) + for i, user := range g.GroupUsers { + g.Users[i] = user.UserID + } + g.GroupUsers = []GroupUser{} +} + +func (g *Group) StoreGroupUsers() { + g.GroupUsers = make([]GroupUser, len(g.Users)) + for i, user := range g.Users { + g.GroupUsers[i] = GroupUser{ + AccountID: g.AccountID, + GroupID: g.ID, + UserID: user, + } + } + g.Users = []string{} +} + // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -78,11 +106,13 @@ func (g *Group) Copy() *Group { Issued: g.Issued, Peers: make([]string, len(g.Peers)), GroupPeers: make([]GroupPeer, len(g.GroupPeers)), + GroupUsers: make([]GroupUser, len(g.GroupUsers)), Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) copy(group.GroupPeers, g.GroupPeers) + copy(group.GroupUsers, g.GroupUsers) copy(group.Resources, g.Resources) return group } diff --git a/management/server/types/user.go b/management/server/types/user.go index dc601e15b..8c1aaf1c5 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -85,9 +85,11 @@ type User struct { // ServiceUserName is only set if IsServiceUser is true ServiceUserName string // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user - AutoGroups []string `gorm:"serializer:json"` - PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` + AutoGroups []string `gorm:"-"` + // GroupUsers replaces old AutoGroups + Groups []*GroupUser `gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id;constraint:OnDelete:CASCADE;"` // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool // PendingApproval indicates whether the user requires approval before being activated @@ -106,6 +108,24 @@ type User struct { Email string `gorm:"default:''"` } +func (u *User) LoadAutoGroups() { + u.AutoGroups = make([]string, 0, len(u.Groups)) + for _, group := range u.Groups { + u.AutoGroups = append(u.AutoGroups, group.GroupID) + } +} + +func (u *User) StoreAutoGroups() { + u.Groups = make([]*GroupUser, 0, len(u.Groups)) + for _, groupID := range u.AutoGroups { + u.Groups = append(u.Groups, &GroupUser{ + AccountID: u.AccountID, + GroupID: groupID, + UserID: u.Id, + }) + } +} + // IsBlocked returns true if the user is blocked, false otherwise func (u *User) IsBlocked() bool { return u.Blocked @@ -198,8 +218,11 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { // Copy the user func (u *User) Copy() *User { + groupUsers := make([]*GroupUser, len(u.Groups)) + copy(groupUsers, u.Groups) autoGroups := make([]string, len(u.AutoGroups)) copy(autoGroups, u.AutoGroups) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) for k, v := range u.PATs { pats[k] = v.Copy() @@ -221,6 +244,7 @@ func (u *User) Copy() *User { IntegrationReference: u.IntegrationReference, Email: u.Email, Name: u.Name, + Groups: groupUsers, } } diff --git a/management/server/user.go b/management/server/user.go index 4f9007b61..49aa35698 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -45,7 +45,21 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI newUser.AccountID = accountID log.WithContext(ctx).Debugf("New User: %v", newUser) - if err = am.Store.SaveUser(ctx, newUser); err != nil { + if err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error { + err = tx.SaveUser(ctx, newUser) + if err != nil { + return err + } + + for _, groupID := range autoGroups { + err = tx.AddUserToGroup(ctx, accountID, newUserID, groupID) + if err != nil { + return fmt.Errorf("failed to add user to group %s: %w", groupID, err) + } + } + + return nil + }); err != nil { return nil, err } @@ -119,7 +133,6 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u Id: idpUser.ID, AccountID: accountID, Role: types.StrRoleToUserRole(invite.Role), - AutoGroups: invite.AutoGroups, Issued: invite.Issued, IntegrationReference: invite.IntegrationReference, CreatedAt: time.Now().UTC(), @@ -127,6 +140,23 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u Name: invite.Name, } + err = am.Store.ExecuteInTransaction(ctx, func(tx store.Store) error { + err = tx.SaveUser(ctx, newUser) + if err != nil { + return err + } + for _, group := range invite.AutoGroups { + err = tx.AddUserToGroup(ctx, accountID, userID, group) + if err != nil { + return fmt.Errorf("failed to add user to group %s: %w", group, err) + } + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to save user: %w", err) + } + if err = am.Store.SaveUser(ctx, newUser); err != nil { return nil, err } @@ -737,28 +767,63 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact peersToExpire = userPeers } - var removedGroups, addedGroups []string - if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups) + updateAccountPeers, removedGroupsIDs, addedGroupsIDs, err := am.processUserGroupsUpdate(ctx, transaction, oldUser, updatedUser, userPeers, settings) + if err != nil { + return false, nil, nil, nil, err + } + + updatedUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, updatedUser.Id) + if err != nil { + return false, nil, nil, nil, err + } + + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction) + + return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil +} + +func (am *DefaultAccountManager) processUserGroupsUpdate(ctx context.Context, transaction store.Store, oldUser *types.User, updatedUser *types.User, userPeers []*nbpeer.Peer, settings *types.Settings) (bool, []string, []string, error) { + removedGroups := util.Difference(oldUser.AutoGroups, updatedUser.AutoGroups) + addedGroups := util.Difference(updatedUser.AutoGroups, oldUser.AutoGroups) + + updateAccountPeers := len(userPeers) > 0 + + removedGroupsIDs := make([]string, 0, len(removedGroups)) + for _, id := range removedGroups { + err := transaction.RemoveUserFromGroup(ctx, updatedUser.Id, id) + if err != nil { + return false, nil, nil, fmt.Errorf("failed to remove user %s from group %s: %w", updatedUser.Id, id, err) + } + updateAccountPeers = true + removedGroupsIDs = append(removedGroupsIDs, id) + } + + addedGroupsIDs := make([]string, 0, len(addedGroups)) + for _, id := range addedGroups { + err := transaction.AddUserToGroup(ctx, updatedUser.AccountID, updatedUser.Id, id) + if err != nil { + return false, nil, nil, fmt.Errorf("failed to add user %s to group %s: %w", updatedUser.Id, id, err) + } + updateAccountPeers = true + addedGroupsIDs = append(addedGroupsIDs, id) + } + + if updatedUser.Groups != nil && settings.GroupsPropagationEnabled { for _, peer := range userPeers { - for _, groupID := range removedGroups { - if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { - return false, nil, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, groupID, err) + for _, id := range removedGroups { + if err := transaction.RemovePeerFromGroup(ctx, peer.ID, id); err != nil { + return false, nil, nil, fmt.Errorf("failed to remove peer %s from group %s: %w", peer.ID, id, err) } } - for _, groupID := range addedGroups { - if err := transaction.AddPeerToGroup(ctx, accountID, peer.ID, groupID); err != nil { - return false, nil, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, groupID, err) + for _, id := range addedGroups { + if err := transaction.AddPeerToGroup(ctx, updatedUser.AccountID, peer.ID, id); err != nil { + return false, nil, nil, fmt.Errorf("failed to add peer %s to group %s: %w", peer.ID, id, err) } } } } - updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroups, addedGroups, transaction) - - return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil + return updateAccountPeers, removedGroupsIDs, addedGroupsIDs, nil } // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.