diff --git a/management/server/account.go b/management/server/account.go index 6738235a0..af112c25b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1404,7 +1404,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups) for _, group := range addNewGroups { - err = transaction.AddUserToGroup(ctx, userAuth.AccountId, group, user.Id) + err = transaction.AddUserToGroup(ctx, userAuth.AccountId, user.Id, group) if err != nil { return fmt.Errorf("error adding user to group: %w", err) } @@ -1416,6 +1416,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return fmt.Errorf("error removing user from group: %w", err) } } + user.AutoGroups = updatedAutoGroups if err = transaction.SaveUser(ctx, user); err != nil { return fmt.Errorf("error saving user: %w", err) diff --git a/management/server/account_test.go b/management/server/account_test.go index 59d6e4928..9d032fe13 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1723,6 +1723,13 @@ func TestAccount_Copy(t *testing.T) { Id: "user1", Role: types.UserRoleAdmin, AutoGroups: []string{"group1"}, + Groups: []*types.GroupUser{ + { + AccountID: "account1", + UserID: "user1", + GroupID: "group1", + }, + }, PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index d9323e107..c070fa5bb 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -177,7 +177,8 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro generateAccountSQLTypes(account) // Encrypt sensitive user data before saving - for i := range account.UsersG { + for i, user := range account.UsersG { + user.StoreAutoGroups() if err := account.UsersG[i].EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt user: %w", err) } @@ -185,7 +186,6 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro for _, group := range account.GroupsG { group.StoreGroupPeers() - group.StoreGroupUsers() } err := s.transaction(func(tx *gorm.DB) error { @@ -213,6 +213,9 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro } return nil }) + if err != nil { + return err + } took := time.Since(start) if s.metrics != nil { @@ -220,7 +223,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro } log.WithContext(ctx).Debugf("took %d ms to persist an account to the store", took.Milliseconds()) - return err + return nil } // generateAccountSQLTypes generates the GORM compatible types for the account @@ -244,8 +247,8 @@ func generateAccountSQLTypes(account *types.Account) { pat.ID = id user.PATsG = append(user.PATsG, *pat) } - user.LoadAutoGroups() - account.UsersG = append(account.UsersG, *user) + account.UsersG = append(account.UsersG, user) + account.Users = nil } for id, group := range account.Groups { @@ -750,7 +753,6 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr for _, g := range groups { g.LoadGroupPeers() - g.LoadGroupUsers() } return groups, nil @@ -780,7 +782,6 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt for _, g := range groups { g.LoadGroupPeers() - g.LoadGroupUsers() } return groups, nil @@ -881,9 +882,10 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types Preload("SetupKeysG"). Preload("PeersG"). Preload("UsersG"). - Preload("UsersG.GroupUser"). + Preload("UsersG.Groups"). Preload("GroupsG"). Preload("GroupsG.GroupPeers"). + Preload("GroupsG.GroupUsers"). Preload("RoutesG"). Preload("NameServerGroupsG"). Preload("PostureChecks"). @@ -931,7 +933,7 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { return nil, fmt.Errorf("decrypt user: %w", err) } - account.Users[user.Id] = &user + account.Users[user.Id] = user user.PATsG = nil } account.UsersG = nil @@ -1221,7 +1223,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. account.Users = make(map[string]*types.User, len(account.UsersG)) for i := range account.UsersG { - user := &account.UsersG[i] + user := account.UsersG[i] if err := user.DecryptSensitiveData(s.fieldEncrypt); err != nil { return nil, fmt.Errorf("decrypt user: %w", err) } @@ -1630,19 +1632,19 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee return peers, nil } -func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]*types.User, error) { const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type, email, name FROM users WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err } - users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.User, error) { var u types.User var lastLogin, createdAt sql.NullTime var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType, &u.Email, &u.Name) if err != nil { - return u, err + return &u, err } if lastLogin.Valid { @@ -1664,7 +1666,7 @@ func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User u.PendingApproval = pendingApproval.Bool } - return u, nil + return &u, nil }) if err != nil { return nil, err @@ -2828,7 +2830,6 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng for _, group := range groups { group.LoadGroupPeers() - group.LoadGroupUsers() } return groups, nil @@ -3230,7 +3231,6 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } group.LoadGroupPeers() - group.LoadGroupUsers() return group, nil } @@ -3262,7 +3262,6 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } group.LoadGroupPeers() - group.LoadGroupUsers() return &group, nil } @@ -3284,7 +3283,6 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren groupsMap := make(map[string]*types.Group) for _, group := range groups { group.LoadGroupPeers() - group.LoadGroupUsers() groupsMap[group.ID] = group } diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index 350a1da83..2284218ad 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -178,7 +178,7 @@ func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*ty if user.AutoGroups == nil { user.AutoGroups = []string{} } - account.Users[user.Id] = &user + account.Users[user.Id] = user user.PATsG = nil } account.UsersG = nil @@ -900,7 +900,7 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty account.Users = make(map[string]*types.User, len(account.UsersG)) for i := range account.UsersG { - user := &account.UsersG[i] + user := account.UsersG[i] user.PATs = make(map[string]*types.PersonalAccessToken) if userPats, ok := patsByUserID[user.Id]; ok { for j := range userPats { diff --git a/management/server/types/account.go b/management/server/types/account.go index 06170a132..06f6c5a02 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -87,7 +87,7 @@ type Account struct { Peers map[string]*nbpeer.Peer `gorm:"-"` PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` - UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + UsersG []*User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*Group `gorm:"-"` GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` diff --git a/management/server/types/group.go b/management/server/types/group.go index 5aacf5ae9..76f7416a1 100644 --- a/management/server/types/group.go +++ b/management/server/types/group.go @@ -28,7 +28,6 @@ type Group struct { // Peers list of the group Peers []string `gorm:"-"` // Peers and GroupPeers list will be ignored when writing to the DB. Use AddPeerToGroup and RemovePeerFromGroup methods to modify group membership GroupPeers []GroupPeer `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"` - Users []string `gorm:"-"` GroupUsers []GroupUser `gorm:"foreignKey:GroupID;references:id;constraint:OnDelete:CASCADE;"` // Resources contains a list of resources in that group @@ -69,26 +68,6 @@ func (g *Group) StoreGroupPeers() { g.Peers = []string{} } -func (g *Group) LoadGroupUsers() { - g.Users = make([]string, len(g.GroupUsers)) - for i, user := range g.GroupUsers { - g.Users[i] = user.UserID - } - g.GroupUsers = []GroupUser{} -} - -func (g *Group) StoreGroupUsers() { - g.GroupUsers = make([]GroupUser, len(g.Users)) - for i, user := range g.Users { - g.GroupUsers[i] = GroupUser{ - AccountID: g.AccountID, - GroupID: g.ID, - UserID: user, - } - } - g.Users = []string{} -} - // EventMeta returns activity event meta related to the group func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} @@ -99,21 +78,41 @@ func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]an } func (g *Group) Copy() *Group { + var peers []string + if g.Peers != nil { + peers = make([]string, len(g.Peers)) + copy(peers, g.Peers) + } + + var groupPeers []GroupPeer + if g.GroupPeers != nil { + groupPeers = make([]GroupPeer, len(g.GroupPeers)) + copy(groupPeers, g.GroupPeers) + } + + var groupUsers []GroupUser + if g.GroupUsers != nil { + groupUsers = make([]GroupUser, len(g.GroupUsers)) + copy(groupUsers, g.GroupUsers) + } + + var resources []Resource + if g.Resources != nil { + resources = make([]Resource, len(g.Resources)) + copy(resources, g.Resources) + } + group := &Group{ ID: g.ID, AccountID: g.AccountID, Name: g.Name, Issued: g.Issued, - Peers: make([]string, len(g.Peers)), - GroupPeers: make([]GroupPeer, len(g.GroupPeers)), - GroupUsers: make([]GroupUser, len(g.GroupUsers)), - Resources: make([]Resource, len(g.Resources)), + Peers: peers, + GroupPeers: groupPeers, + GroupUsers: groupUsers, + Resources: resources, IntegrationReference: g.IntegrationReference, } - copy(group.Peers, g.Peers) - copy(group.GroupPeers, g.GroupPeers) - copy(group.GroupUsers, g.GroupUsers) - copy(group.Resources, g.Resources) return group } diff --git a/management/server/types/user.go b/management/server/types/user.go index 8c1aaf1c5..761ed09e1 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -113,6 +113,7 @@ func (u *User) LoadAutoGroups() { for _, group := range u.Groups { u.AutoGroups = append(u.AutoGroups, group.GroupID) } + u.Groups = []*GroupUser{} } func (u *User) StoreAutoGroups() { @@ -124,6 +125,7 @@ func (u *User) StoreAutoGroups() { UserID: u.Id, }) } + u.AutoGroups = []string{} } // IsBlocked returns true if the user is blocked, false otherwise @@ -218,10 +220,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) { // Copy the user func (u *User) Copy() *User { - groupUsers := make([]*GroupUser, len(u.Groups)) - copy(groupUsers, u.Groups) - autoGroups := make([]string, len(u.AutoGroups)) - copy(autoGroups, u.AutoGroups) + var groupUsers []*GroupUser + if u.Groups != nil { + groupUsers = make([]*GroupUser, len(u.Groups)) + copy(groupUsers, u.Groups) + } + var autoGroups []string + if u.AutoGroups != nil { + autoGroups = make([]string, len(u.AutoGroups)) + copy(autoGroups, u.AutoGroups) + } pats := make(map[string]*PersonalAccessToken, len(u.PATs)) for k, v := range u.PATs { diff --git a/management/server/user.go b/management/server/user.go index 49aa35698..c5699ab07 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -745,6 +745,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact updatedUser.Role = update.Role updatedUser.Blocked = update.Blocked updatedUser.AutoGroups = update.AutoGroups + updatedUser.StoreAutoGroups() // these two fields can't be set via API, only via direct call to the method updatedUser.Issued = update.Issued updatedUser.IntegrationReference = update.IntegrationReference @@ -772,11 +773,6 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact return false, nil, nil, nil, err } - updatedUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, updatedUser.Id) - if err != nil { - return false, nil, nil, nil, err - } - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, isNewUser, removedGroupsIDs, addedGroupsIDs, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil diff --git a/management/server/user_test.go b/management/server/user_test.go index 6d356a8b1..bb7d84efe 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -345,6 +345,9 @@ func TestUser_Copy(t *testing.T) { IsServiceUser: true, ServiceUserName: "servicename", AutoGroups: []string{"group1", "group2"}, + Groups: []*types.GroupUser{ + {AccountID: "accountId", GroupID: "groupId", UserID: "userId"}, + }, PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", @@ -1338,6 +1341,8 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Fatal(err) } + account, err = manager.Store.GetAccount(context.Background(), account.Id) + updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) if tc.expectedErr { require.Errorf(t, err, "expecting SaveUser to throw an error") @@ -1653,6 +1658,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { LastLogin: time.Time{}, Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, + AutoGroups: []string{}, }, Permissions: mergeRolePermissions(roles.User), }, @@ -1672,6 +1678,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { LastLogin: time.Time{}, Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, + AutoGroups: []string{}, }, Permissions: mergeRolePermissions(roles.Admin), }, @@ -1691,6 +1698,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { LastLogin: time.Time{}, Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, + AutoGroups: []string{}, }, Permissions: mergeRolePermissions(roles.User), Restricted: true, @@ -1712,6 +1720,7 @@ func TestDefaultAccountManager_GetCurrentUserInfo(t *testing.T) { LastLogin: time.Time{}, Issued: "api", IntegrationReference: integration_reference.IntegrationReference{}, + AutoGroups: []string{}, }, Permissions: mergeRolePermissions(roles.User), Restricted: false,