diff --git a/management/server/account.go b/management/server/account.go index ea081f0e0..1e211b919 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1,6 +1,7 @@ package server import ( + "fmt" "reflect" "strings" "sync" @@ -375,6 +376,10 @@ func (am *DefaultAccountManager) handleNewUserAccount( if domainAcc != nil { account = domainAcc account.Users[claims.UserId] = NewRegularUser(claims.UserId) + err = am.Store.SaveAccount(account) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed saving updated account") + } } else { account = NewAccount(claims.UserId, lowerDomain) account.Users[claims.UserId] = NewAdminUser(claims.UserId) @@ -417,10 +422,13 @@ func (am *DefaultAccountManager) GetAccountWithAuthorizationClaims( if claims.DomainCategory != PrivateCategory { return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) + accountFromID, err := am.GetAccountById(claims.AccountId) if err != nil { return nil, err } + if _, ok := accountFromID.Users[claims.UserId]; !ok { + return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + } if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory { return accountFromID, nil }