From ccab3b427fe11b6343916373c2599dbfaef7f9fd Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 18 Sep 2024 14:24:39 +0300 Subject: [PATCH] refactor getAccountFromToken Signed-off-by: bcmmbaga --- management/server/account.go | 195 ++++++++++++++++++------------ management/server/account_test.go | 24 +++- 2 files changed, 135 insertions(+), 84 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index b470e8036..0108c2758 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1625,13 +1625,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + + account, err := am.Store.GetAccount(ctx, accountID) + unlock() + if err != nil { + return err + } + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err @@ -1751,94 +1759,112 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) - if err != nil { - return "", "", err - } - unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id) - alreadyUnlocked := false - defer func() { - if !alreadyUnlocked { - unlock() - } - }() - - account, err := am.Store.GetAccount(ctx, newAcc.Id) + accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims) if err != nil { return "", "", err } - user := account.Users[claims.UserId] - if user == nil { + user, err := am.Store.GetUserByUserID(ctx, claims.UserId) + if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(ctx, account, claims.UserId) + err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { return "", "", err } } - if account.Settings.JWTGroupsEnabled { - if account.Settings.JWTGroupsClaimName == "" { - log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + if err = am.syncJWTGroups(ctx, claims, accountID); err != nil { + return "", "", err + } - return account.Id, user.Id, nil + return accountID, user.Id, nil +} + +// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, +// and propagates changes to peers if group propagation is enabled. +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error { + settings, err := am.Store.GetAccountSettings(ctx, accountID) + if err != nil { + return err + } + + if !settings.JWTGroupsEnabled { + return nil + } + + if settings.JWTGroupsClaimName == "" { + log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") + return nil + } + + jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) + + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + + user, err := account.FindUser(claims.UserId) + if err != nil { + return nil + } + + oldGroups := make([]string, len(user.AutoGroups)) + copy(oldGroups, user.AutoGroups) + + // Update the account if group membership changes + if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { + addNewGroups := difference(user.AutoGroups, oldGroups) + removeOldGroups := difference(oldGroups, user.AutoGroups) + + if settings.GroupsPropagationEnabled { + account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) + account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) + account.Network.IncSerial() } - jwtGroupsNames := extractJWTGroups(ctx, account.Settings.JWTGroupsClaimName, claims) + if err := am.Store.SaveAccount(ctx, account); err != nil { + log.WithContext(ctx).Errorf("failed to save account: %v", err) + return nil + } - oldGroups := make([]string, len(user.AutoGroups)) - copy(oldGroups, user.AutoGroups) - // if groups were added or modified, save the account - if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { - if account.Settings.GroupsPropagationEnabled { - if user, err := account.FindUser(claims.UserId); err == nil { - addNewGroups := difference(user.AutoGroups, oldGroups) - removeOldGroups := difference(oldGroups, user.AutoGroups) - account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) - account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) - account.Network.IncSerial() - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } else { - log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) - unlock() - alreadyUnlocked = true - for _, g := range addNewGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - for _, g := range removeOldGroups { - if group := account.GetGroup(g); group != nil { - am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{ - "group": group.Name, - "group_id": group.ID, - "is_service_user": user.IsServiceUser, - "user_name": user.ServiceUserName}) - } - } - } - } - } else { - if err := am.Store.SaveAccount(ctx, account); err != nil { - log.WithContext(ctx).Errorf("failed to save account: %v", err) - } + // Propagate changes to peers if group propagation is enabled + if settings.GroupsPropagationEnabled { + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) + am.updateAccountPeers(ctx, account) + } + + for _, g := range addNewGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) + } + } + + for _, g := range removeOldGroups { + if group := account.GetGroup(g); group != nil { + am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{ + "group": group.Name, + "group_id": group.ID, + "is_service_user": user.IsServiceUser, + "user_name": user.ServiceUserName}) } } } - return account.Id, user.Id, nil + return nil } // getAccountWithAuthorizationClaims retrievs an account using JWT Claims. @@ -1858,27 +1884,31 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + if err != nil { + return "", nil + } + return account.Id, nil } else if claims.AccountId != "" { accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) if err != nil { - return nil, err + return "", err } if _, ok := accountFromID.Users[claims.UserId]; !ok { - return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { - return accountFromID, nil + return accountFromID.Id, nil } } @@ -1893,7 +1923,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { - return nil, err + return "", err } } @@ -1903,7 +1933,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C defer unlockAccount() account, err = am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { - return nil, err + return "", err } // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as @@ -1914,22 +1944,27 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) if err != nil { - return nil, err + return "", err } - return account, nil + return account.Id, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if domainAccount != nil { unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { - return nil, err + return "", err } } - return am.handleNewUserAccount(ctx, domainAccount, claims) + + account, err = am.handleNewUserAccount(ctx, domainAccount, claims) + if err != nil { + return "", err + } + return account.Id, nil } else { // other error - return nil, err + return "", err } } diff --git a/management/server/account_test.go b/management/server/account_test.go index 03b5fa83e..c6fac07b2 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -645,8 +645,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { testCase.inputClaims.AccountId = initAccount.Id } - account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) + accountID, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) require.NoError(t, err, "support function failed") + + account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "") + require.NoError(t, err, "get account by account id") + verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) @@ -685,8 +689,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { } t.Run("JWT groups disabled", func(t *testing.T) { - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "") + require.NoError(t, err, "get account by account id") + require.Len(t, account.Groups, 1, "only ALL group should exists") }) @@ -696,8 +704,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "") + require.NoError(t, err, "get account by account id") + require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") }) @@ -708,8 +720,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.NoError(t, err, "save account failed") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") - account, _, err := manager.GetAccountFromToken(context.Background(), claims) + accountID, _, err := manager.GetAccountFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") + + account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "") + require.NoError(t, err, "get account by account id") + require.Len(t, account.Groups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{}