refactor getAccountFromToken

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-09-18 14:24:39 +03:00
parent e5d55d3c10
commit ccab3b427f
2 changed files with 135 additions and 84 deletions

View File

@@ -1625,13 +1625,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
} }
// redeemInvite checks whether user has been invited and redeems the invite // redeemInvite checks whether user has been invited and redeems the invite
func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID, userID string) error {
// only possible with the enabled IdP manager // only possible with the enabled IdP manager
if am.idpManager == nil { if am.idpManager == nil {
log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") log.WithContext(ctx).Warnf("invites only work with enabled IdP manager")
return nil return nil
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
account, err := am.Store.GetAccount(ctx, accountID)
unlock()
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account) user, err := am.lookupUserInCache(ctx, userID, account)
if err != nil { if err != nil {
return err return err
@@ -1751,94 +1759,112 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled")
} }
newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) accountID, err := am.getAccountWithAuthorizationClaims(ctx, claims)
if err != nil {
return "", "", err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id)
alreadyUnlocked := false
defer func() {
if !alreadyUnlocked {
unlock()
}
}()
account, err := am.Store.GetAccount(ctx, newAcc.Id)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
user := account.Users[claims.UserId] user, err := am.Store.GetUserByUserID(ctx, claims.UserId)
if user == nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId)
} }
if !user.IsServiceUser && claims.Invited { if !user.IsServiceUser && claims.Invited {
err = am.redeemInvite(ctx, account, claims.UserId) err = am.redeemInvite(ctx, accountID, user.Id)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
} }
if account.Settings.JWTGroupsEnabled { if err = am.syncJWTGroups(ctx, claims, accountID); err != nil {
if account.Settings.JWTGroupsClaimName == "" { return "", "", err
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") }
return account.Id, user.Id, nil return accountID, user.Id, nil
}
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims, accountID string) error {
settings, err := am.Store.GetAccountSettings(ctx, accountID)
if err != nil {
return err
}
if !settings.JWTGroupsEnabled {
return nil
}
if settings.JWTGroupsClaimName == "" {
log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set")
return nil
}
jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims)
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
user, err := account.FindUser(claims.UserId)
if err != nil {
return nil
}
oldGroups := make([]string, len(user.AutoGroups))
copy(oldGroups, user.AutoGroups)
// Update the account if group membership changes
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) {
addNewGroups := difference(user.AutoGroups, oldGroups)
removeOldGroups := difference(oldGroups, user.AutoGroups)
if settings.GroupsPropagationEnabled {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...)
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...)
account.Network.IncSerial()
} }
jwtGroupsNames := extractJWTGroups(ctx, account.Settings.JWTGroupsClaimName, claims) if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
return nil
}
oldGroups := make([]string, len(user.AutoGroups)) // Propagate changes to peers if group propagation is enabled
copy(oldGroups, user.AutoGroups) if settings.GroupsPropagationEnabled {
// if groups were added or modified, save the account log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
if account.SetJWTGroups(claims.UserId, jwtGroupsNames) { am.updateAccountPeers(ctx, account)
if account.Settings.GroupsPropagationEnabled { }
if user, err := account.FindUser(claims.UserId); err == nil {
addNewGroups := difference(user.AutoGroups, oldGroups) for _, g := range addNewGroups {
removeOldGroups := difference(oldGroups, user.AutoGroups) if group := account.GetGroup(g); group != nil {
account.UserGroupsAddToPeers(claims.UserId, addNewGroups...) am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser,
account.UserGroupsRemoveFromPeers(claims.UserId, removeOldGroups...) map[string]any{
account.Network.IncSerial() "group": group.Name,
if err := am.Store.SaveAccount(ctx, account); err != nil { "group_id": group.ID,
log.WithContext(ctx).Errorf("failed to save account: %v", err) "is_service_user": user.IsServiceUser,
} else { "user_name": user.ServiceUserName})
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) }
am.updateAccountPeers(ctx, account) }
unlock()
alreadyUnlocked = true for _, g := range removeOldGroups {
for _, g := range addNewGroups { if group := account.GetGroup(g); group != nil {
if group := account.GetGroup(g); group != nil { am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupAddedToUser, map[string]any{
map[string]any{ "group": group.Name,
"group": group.Name, "group_id": group.ID,
"group_id": group.ID, "is_service_user": user.IsServiceUser,
"is_service_user": user.IsServiceUser, "user_name": user.ServiceUserName})
"user_name": user.ServiceUserName})
}
}
for _, g := range removeOldGroups {
if group := account.GetGroup(g); group != nil {
am.StoreEvent(ctx, user.Id, user.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": user.IsServiceUser,
"user_name": user.ServiceUserName})
}
}
}
}
} else {
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).Errorf("failed to save account: %v", err)
}
} }
} }
} }
return account.Id, user.Id, nil return nil
} }
// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. // getAccountWithAuthorizationClaims retrievs an account using JWT Claims.
@@ -1858,27 +1884,31 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims
// Existing user + Existing account + Existing Indexed Domain -> Nothing changes // Existing user + Existing account + Existing Indexed Domain -> Nothing changes
// //
// Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain)
func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"",
claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory)
if claims.UserId == "" { if claims.UserId == "" {
return nil, fmt.Errorf("user ID is empty") return "", fmt.Errorf("user ID is empty")
} }
// if Account ID is part of the claims // if Account ID is part of the claims
// it means that we've already classified the domain and user has an account // it means that we've already classified the domain and user has an account
if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) {
return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) account, err := am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain)
if err != nil {
return "", nil
}
return account.Id, nil
} else if claims.AccountId != "" { } else if claims.AccountId != "" {
accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId)
if err != nil { if err != nil {
return nil, err return "", err
} }
if _, ok := accountFromID.Users[claims.UserId]; !ok { if _, ok := accountFromID.Users[claims.UserId]; !ok {
return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId)
} }
if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain {
return accountFromID, nil return accountFromID.Id, nil
} }
} }
@@ -1893,7 +1923,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
// if NotFound we are good to continue, otherwise return error // if NotFound we are good to continue, otherwise return error
e, ok := status.FromError(err) e, ok := status.FromError(err)
if !ok || e.Type() != status.NotFound { if !ok || e.Type() != status.NotFound {
return nil, err return "", err
} }
} }
@@ -1903,7 +1933,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
defer unlockAccount() defer unlockAccount()
account, err = am.Store.GetAccountByUser(ctx, claims.UserId) account, err = am.Store.GetAccountByUser(ctx, claims.UserId)
if err != nil { if err != nil {
return nil, err return "", err
} }
// If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise,
// we compare the account's ID with the domain account ID, and if they don't match, we set the account as // we compare the account's ID with the domain account ID, and if they don't match, we set the account as
@@ -1914,22 +1944,27 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C
err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims)
if err != nil { if err != nil {
return nil, err return "", err
} }
return account, nil return account.Id, nil
} else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
if domainAccount != nil { if domainAccount != nil {
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccount.Id)
defer unlockAccount() defer unlockAccount()
domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain)
if err != nil { if err != nil {
return nil, err return "", err
} }
} }
return am.handleNewUserAccount(ctx, domainAccount, claims)
account, err = am.handleNewUserAccount(ctx, domainAccount, claims)
if err != nil {
return "", err
}
return account.Id, nil
} else { } else {
// other error // other error
return nil, err return "", err
} }
} }

View File

@@ -645,8 +645,12 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) {
testCase.inputClaims.AccountId = initAccount.Id testCase.inputClaims.AccountId = initAccount.Id
} }
account, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims) accountID, _, err := manager.GetAccountFromToken(context.Background(), testCase.inputClaims)
require.NoError(t, err, "support function failed") require.NoError(t, err, "support function failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers) verifyNewAccountHasDefaultFields(t, account, testCase.expectedCreatedBy, testCase.inputClaims.Domain, testCase.expectedUsers)
verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy) verifyCanAddPeerToAccount(t, manager, account, testCase.expectedCreatedBy)
@@ -685,8 +689,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
} }
t.Run("JWT groups disabled", func(t *testing.T) { t.Run("JWT groups disabled", func(t *testing.T) {
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
require.Len(t, account.Groups, 1, "only ALL group should exists") require.Len(t, account.Groups, 1, "only ALL group should exists")
}) })
@@ -696,8 +704,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT")
}) })
@@ -708,8 +720,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) {
require.NoError(t, err, "save account failed") require.NoError(t, err, "save account failed")
require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist")
account, _, err := manager.GetAccountFromToken(context.Background(), claims) accountID, _, err := manager.GetAccountFromToken(context.Background(), claims)
require.NoError(t, err, "get account by token failed") require.NoError(t, err, "get account by token failed")
account, err := manager.GetAccountByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "get account by account id")
require.Len(t, account.Groups, 3, "groups should be added to the account") require.Len(t, account.Groups, 3, "groups should be added to the account")
groupsByNames := map[string]*group.Group{} groupsByNames := map[string]*group.Group{}