diff --git a/management/server/account.go b/management/server/account.go index 7a9f104ae..7c80ae6e5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -75,7 +75,7 @@ type AccountManager interface { SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) + GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) @@ -1252,25 +1252,30 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return nil } -// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and +// GetAccountIDByUserOrAccountID looks for an account by user or accountID, if no account is provided and // userID doesn't have an account associated with it, one account is created // domain is used to create a new account if no account is found -func (am *DefaultAccountManager) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (string, error) { if accountID != "" { - return am.Store.GetAccount(ctx, accountID) + _, _, err := am.Store.GetAccountDomainAndCategory(ctx, accountID) + if err != nil { + return "", err + } + return accountID, nil } else if userID != "" { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) if err != nil { - return nil, status.Errorf(status.NotFound, "account not found using user id: %s", userID) + return "", status.Errorf(status.NotFound, "account not found using user id: %s", userID) } + err = am.addAccountIDToIDPAppMeta(ctx, userID, account) if err != nil { - return nil, err + return "", err } - return account, nil + return account.Id, nil } - return nil, status.Errorf(status.NotFound, "no valid user or account Id provided") + return "", status.Errorf(status.NotFound, "no valid user or account Id provided") } func isNil(i idp.Manager) bool { @@ -1613,13 +1618,21 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai } // redeemInvite checks whether user has been invited and redeems the invite -func (am *DefaultAccountManager) redeemInvite(ctx context.Context, account *Account, userID string) error { +func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID string, userID string) error { // only possible with the enabled IdP manager if am.idpManager == nil { log.WithContext(ctx).Warnf("invites only work with enabled IdP manager") return nil } + unlock := am.Store.AcquireReadLockByUID(ctx, accountID) + + account, err := am.Store.GetAccount(ctx, accountID) + unlock() + if err != nil { + return err + } + user, err := am.lookupUserInCache(ctx, userID, account) if err != nil { return err @@ -1739,7 +1752,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } - account, err := am.getAccountWithAuthorizationClaims(ctx, claims) + accountID, err := am.getAccountIDWithAuthorizationClaims(ctx, claims) if err != nil { return nil, nil, err } @@ -1751,26 +1764,28 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims } if !user.IsServiceUser && claims.Invited { - err = am.redeemInvite(ctx, account, user.Id) + err = am.redeemInvite(ctx, accountID, user.Id) if err != nil { return nil, nil, err } } - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - if err = am.syncJWTGroups(ctx, account, user, claims); err != nil { + if err = am.syncJWTGroups(ctx, accountID, user, claims); err != nil { return nil, nil, err } - return account, user, nil + // TODO: return account id, user id and error + return &Account{Id: accountID}, user, nil } // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. -func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Account, user *User, claims jwtclaims.AuthorizationClaims) error { - settings := account.Settings +func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, user *User, claims jwtclaims.AuthorizationClaims) error { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + if settings == nil || !settings.JWTGroupsEnabled { return nil } @@ -1780,6 +1795,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Acc return nil } + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return err + } + jwtGroupsNames := extractJWTGroups(ctx, settings.JWTGroupsClaimName, claims) oldGroups := make([]string, len(user.AutoGroups)) @@ -1833,7 +1856,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Acc return nil } -// getAccountWithAuthorizationClaims retrievs an account using JWT Claims. +// getAccountIDWithAuthorizationClaims retrieves an account ID using JWT Claims. // if domain is of the PrivateCategory category, it will evaluate // if account is new, existing or if there is another account with the same domain // @@ -1850,27 +1873,34 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, account *Acc // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, error) { +func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { log.WithContext(ctx).Tracef("getting account with authorization claims. User ID: \"%s\", Account ID: \"%s\", Domain: \"%s\", Domain Category: \"%s\"", claims.UserId, claims.AccountId, claims.Domain, claims.DomainCategory) if claims.UserId == "" { - return nil, fmt.Errorf("user ID is empty") + return "", fmt.Errorf("user ID is empty") } // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountIDByUserOrAccountID(ctx, claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.Store.GetAccount(ctx, claims.AccountId) + userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) if err != nil { - return nil, err + return "", err } - if _, ok := accountFromID.Users[claims.UserId]; !ok { - return nil, fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) + + if userAccountID != claims.AccountId { + return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - if accountFromID.DomainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || accountFromID.Domain != claims.Domain { - return accountFromID, nil + + domain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, claims.AccountId) + if err != nil { + return "", err + } + + if domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain { + return userAccountID, nil } } @@ -1885,7 +1915,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C // if NotFound we are good to continue, otherwise return error e, ok := status.FromError(err) if !ok || e.Type() != status.NotFound { - return nil, err + return "", err } } @@ -1895,7 +1925,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C defer unlockAccount() account, err := am.Store.GetAccountByUser(ctx, claims.UserId) if err != nil { - return nil, err + return "", err } // If there is no primary domain account yet, we set the account as primary for the domain. Otherwise, // we compare the account's ID with the domain account ID, and if they don't match, we set the account as @@ -1903,12 +1933,11 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C // was previously unclassified or classified as public so N users that logged int that time, has they own account // and peers that shouldn't be lost. primaryDomain := domainAccountID == "" || account.Id == domainAccountID - - err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims) - if err != nil { - return nil, err + if err = am.handleExistingUserAccount(ctx, account, primaryDomain, claims); err != nil { + return "", err } - return account, nil + + return account.Id, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { var domainAccount *Account if domainAccountID != "" { @@ -1916,14 +1945,18 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(ctx, claims.Domain) if err != nil { - return nil, err + return "", err } } - return am.handleNewUserAccount(ctx, domainAccount, claims) + account, err := am.handleNewUserAccount(ctx, domainAccount, claims) + if err != nil { + return "", err + } + return account.Id, nil } else { // other error - return nil, err + return "", err } } diff --git a/management/server/file_store.go b/management/server/file_store.go index 434a7710d..1b61b2a68 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -931,7 +931,7 @@ func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID strin return nil } -func (s *FileStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { +func (s *FileStore) GetPostureCheckByChecksDefinition(_ string, _ *posture.ChecksDefinition) (*posture.Checks, error) { return nil, status.Errorf(status.Internal, "GetPostureCheckByChecksDefinition is not implemented") } @@ -950,14 +950,18 @@ func (s *FileStore) GetStoreEngine() StoreEngine { return FileStoreEngine } -func (s *FileStore) SaveUsers(accountID string, users map[string]*User) error { +func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error { return status.Errorf(status.Internal, "SaveUsers is not implemented") } -func (s *FileStore) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error { +func (s *FileStore) SaveGroups(_ string, _ map[string]*nbgroup.Group) error { return status.Errorf(status.Internal, "SaveGroups is not implemented") } -func (s *FileStore) GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) { +func (s *FileStore) GetAccountIDByPrivateDomain(_ context.Context, _ string) (string, error) { return "", status.Errorf(status.Internal, "GetAccountIDByPrivateDomain is not implemented") } + +func (s *FileStore) GetAccountDomainAndCategory(_ context.Context, _ string) (string, string, error) { + return "", "", status.Errorf(status.Internal, "GetAccountDomainAndCategory is not implemented") +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index a95acfd54..6a667b398 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1033,3 +1033,18 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { db: tx, } } + +// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. +func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) { + var account Account + result := s.db.WithContext(ctx).Model(&Account{}).Select("domain", "domain_category"). + Where(idQueryCondition, accountID).First(&account) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", "", status.Errorf(status.NotFound, "account not found") + } + return "", "", status.Errorf(status.Internal, "failed to retrieve account fields") + } + + return account.Domain, account.DomainCategory, nil +} diff --git a/management/server/store.go b/management/server/store.go index 3bf15c1b5..54a559605 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -39,6 +39,7 @@ const ( type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) + GetAccountDomainAndCategory(ctx context.Context, accountID string) (string, string, error) DeleteAccount(ctx context.Context, account *Account) error GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) diff --git a/management/server/user.go b/management/server/user.go index 7e5574e4b..193333685 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -360,16 +360,11 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(ctx, claims) + account, user, err := am.GetAccountFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - user, ok := account.Users[claims.UserId] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") - } - // this code should be outside of the am.GetAccountFromToken(claims) because this method is called also by the gRPC // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. newLogin := user.LastDashboardLoginChanged(claims.LastLogin)