diff --git a/management/server/account.go b/management/server/account.go index 726ef0173..59c9c7fb0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -85,7 +85,7 @@ type AccountManager interface { GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error GetUserByID(ctx context.Context, id string) (*User, error) @@ -1363,13 +1363,13 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u continue } - deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id) + _, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, otherUser.Id) if deleteUserErr != nil { return deleteUserErr } } - err = am.deleteRegularUser(ctx, account, userID, userID) + _, err = am.deleteRegularUser(ctx, accountID, userID, userID) if err != nil { log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err) return err @@ -1426,20 +1426,8 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - cachedAccount := &Account{ - Id: accountID, - Users: make(map[string]*User), - } - for _, user := range accountUsers { - cachedAccount.Users[user.Id] = user - } - // user can be nil if it wasn't found (e.g., just created) - user, err := am.lookupUserInCache(ctx, userID, cachedAccount) + user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { return err } @@ -1515,10 +1503,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e } // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { - users := make(map[string]userLoggedInOnce, len(account.Users)) +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + users := make(map[string]userLoggedInOnce, len(accountUsers)) // ignore service users and users provisioned by integrations than are never logged in - for _, user := range account.Users { + for _, user := range accountUsers { if user.IsServiceUser { continue } @@ -1527,8 +1520,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) } - log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id) - userData, err := am.lookupCache(ctx, users, account.Id) + log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID) + userData, err := am.lookupCache(ctx, users, accountID) if err != nil { return nil, err } @@ -1541,13 +1534,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta - user, err := account.FindUser(userID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id) + log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) return nil, err } - key := user.IntegrationReference.CacheKey(account.Id, userID) + key := user.IntegrationReference.CacheKey(accountID, userID) ud, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err) @@ -1787,9 +1780,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - usersMap := make(map[string]*User) - usersMap[claims.UserId] = NewRegularUser(claims.UserId) - err := am.Store.SaveUsers(domainAccountID, usersMap) + newUser := NewRegularUser(claims.UserId) + newUser.AccountID = domainAccountID + err := am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser) if err != nil { return "", err } @@ -1812,12 +1805,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str return nil } - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := am.lookupUserInCache(ctx, userID, account) + user, err := am.lookupUserInCache(ctx, userID, accountID) if err != nil { return err } @@ -1827,17 +1815,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str } if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite { - log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id) + log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID) // User has already logged in, meaning that IdP should have set wt_pending_invite to false. // Our job is to just reload cache. go func() { - _, err = am.refreshCache(ctx, account.Id) + _, err = am.refreshCache(ctx, accountID) if err != nil { - log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id) + log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID) return } - log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id) - am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil) + log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID) + am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil) }() } @@ -1846,33 +1834,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - - user, err := am.Store.GetUserByTokenID(ctx, tokenID) - if err != nil { - return err - } - - account, err := am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - pat, ok := account.Users[user.Id].PATs[tokenID] - if !ok { - return fmt.Errorf("token not found") - } - - pat.LastUsed = time.Now().UTC() - - return am.Store.SaveAccount(ctx, account) + return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID) } // GetAccount returns an account associated with this account ID. @@ -1880,52 +1842,65 @@ func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID strin return am.Store.GetAccount(ctx, accountID) } -// GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { +// GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token. +func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) { + user, pat, err = am.extractPATFromToken(ctx, token) + if err != nil { + return nil, nil, "", "", err + } + + domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID) + if err != nil { + return nil, nil, "", "", err + } + + return user, pat, domain, category, nil +} + +// extractPATFromToken validates the token structure and retrieves associated User and PAT. +func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) { if len(token) != PATLength { - return nil, nil, nil, fmt.Errorf("token has wrong length") + return nil, nil, fmt.Errorf("token has incorrect length") } prefix := token[:len(PATPrefix)] if prefix != PATPrefix { - return nil, nil, nil, fmt.Errorf("token has wrong prefix") + return nil, nil, fmt.Errorf("token has incorrect prefix") } + secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { - return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) + return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err) } secretChecksum := crc32.ChecksumIEEE([]byte(secret)) if secretChecksum != verificationChecksum { - return nil, nil, nil, fmt.Errorf("token checksum does not match") + return nil, nil, fmt.Errorf("token checksum does not match") } hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken) + + var user *User + var pat *PersonalAccessToken + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + pat, err = transaction.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken) + if err != nil { + return err + } + + user, err = transaction.GetUserByPATID(ctx, LockingStrengthShare, pat.ID) + return err + }) if err != nil { - return nil, nil, nil, err + return nil, nil, err } - user, err := am.Store.GetUserByTokenID(ctx, tokenID) - if err != nil { - return nil, nil, nil, err - } - - account, err := am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return nil, nil, nil, err - } - - pat := user.PATs[tokenID] - if pat == nil { - return nil, nil, nil, fmt.Errorf("personal access token not found") - } - - return account, user, pat, nil + return user, pat, nil } // GetAccountByID returns an account associated with this account ID. diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e1a84b4f9..3e465e32e 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -54,7 +54,7 @@ type MockAccountManager struct { DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + GetAccountInfoFromPATFunc func(ctx context.Context, token string) (*server.User, *server.PersonalAccessToken, string, string, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -234,12 +234,12 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -// GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { - if am.GetAccountFromPATFunc != nil { - return am.GetAccountFromPATFunc(ctx, pat) +// GetAccountInfoFromPAT mock implementation of GetAccountInfoFromPAT from server.AccountManager interface +func (am *MockAccountManager) GetAccountInfoFromPAT(ctx context.Context, pat string) (*server.User, *server.PersonalAccessToken, string, string, error) { + if am.GetAccountInfoFromPATFunc != nil { + return am.GetAccountInfoFromPATFunc(ctx, pat) } - return nil, nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromPAT is not implemented") + return nil, nil, "", "", status.Errorf(codes.Unimplemented, "method GetAccountInfoFromPAT is not implemented") } // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface