wip: refactoring

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-02 11:56:47 +03:00
parent 78e238646c
commit 0297b5f142
7 changed files with 375 additions and 240 deletions

View File

@@ -1125,6 +1125,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
// TODO: call direct on the store to get expired peers
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID)
@@ -1139,7 +1140,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextPeerExpiration()
}
@@ -1296,7 +1297,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Conte
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
}
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
return "", err
}
@@ -1311,28 +1312,28 @@ func isNil(i idp.Manager) bool {
}
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, account)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
if user != nil && user.AppMetadata.WTAccountID == account.Id {
if user != nil && user.AppMetadata.WTAccountID == accountID {
// it was already set, so we skip the unnecessary update
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
account.Id, userID)
accountID, userID)
return nil
}
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
if err != nil {
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
}
// refresh cache to reflect the update
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
return err
}
@@ -1391,10 +1392,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
}
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]userLoggedInOnce, len(account.Users))
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
users := make(map[string]userLoggedInOnce, len(accountUsers))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
for _, user := range accountUsers {
if user.IsServiceUser {
continue
}
@@ -1403,8 +1409,9 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
}
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
}
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(ctx, users, account.Id)
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
userData, err := am.lookupCache(ctx, users, accountID)
if err != nil {
return nil, err
}
@@ -1417,13 +1424,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err
}
key := user.IntegrationReference.CacheKey(account.Id, userID)
key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
@@ -1591,13 +1598,14 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
primaryDomain bool,
claims jwtclaims.AuthorizationClaims,
) error {
// TODO: remove account as parameter and pass accountID string
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
if err != nil {
return err
}
// we should register the account ID to this user's metadata in our IDP manager
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc.Id)
if err != nil {
return err
}
@@ -1635,7 +1643,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
}
}
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account.Id)
if err != nil {
return nil, err
}
@@ -1653,12 +1661,12 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
return nil
}
account, err := am.Store.GetAccount(ctx, accountID)
_, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
user, err := am.lookupUserInCache(ctx, userID, account)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
@@ -1668,17 +1676,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
}
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
return
}
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
}()
}

View File

@@ -1704,7 +1704,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
require.NoError(t, err, "unable to mark peer connected")
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1821,7 +1821,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
},
}
// enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
})
@@ -1852,13 +1852,13 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false,
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
require.NoError(t, err, "unable to get account by ID")

View File

@@ -959,10 +959,6 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
return FileStoreEngine
}
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error {
return status.Errorf(status.Internal, "SaveGroups is not implemented")
}
@@ -1085,6 +1081,10 @@ func (s *FileStore) GetAccountPeers(_ context.Context, _ LockingStrength, _ stri
return nil, status.Errorf(status.Internal, "GetAccountPeers is not implemented")
}
func (s *FileStore) GetUserPeers(_ context.Context, _ LockingStrength, _, _ string) ([]*nbpeer.Peer, error) {
return nil, status.Errorf(status.Internal, "GetUserPeers is not implemented")
}
func (s *FileStore) GetAccountPeersWithExpiration(_ context.Context, _ LockingStrength, _ string) ([]*nbpeer.Peer, error) {
return nil, status.Errorf(status.Internal, "GetAccountPeersWithExpiration is not implemented")
}
@@ -1127,3 +1127,23 @@ func (s *FileStore) DeleteGroups(_ context.Context, _ LockingStrength, _ []strin
func (s *FileStore) GetAccountUsers(_ context.Context, _ LockingStrength, _ string) ([]*User, error) {
return nil, status.Errorf(status.Internal, "GetAccountUsers is not implemented")
}
func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error {
return status.Errorf(status.Internal, "SaveUser is not implemented")
}
func (s *FileStore) SaveUsers(_ context.Context, _ LockingStrength, _ []*User) error {
return status.Errorf(status.Internal, "SaveUsers is not implemented")
}
func (s *FileStore) DeleteUser(_ context.Context, _ LockingStrength, _, _ string) error {
return status.Errorf(status.Internal, "DeleteUser is not implemented")
}
func (s *FileStore) DeleteUsers(_ context.Context, _ LockingStrength, _ []string, _ string) error {
return status.Errorf(status.Internal, "DeleteUsers is not implemented")
}
func (s *FileStore) GetAccountOwnerID(_ context.Context, _ LockingStrength, _ string) (string, error) {
return "", status.Errorf(status.Internal, "GetAccountOwnerID is not implemented")
}

View File

@@ -32,7 +32,21 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
routes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, account.Id)
if err != nil {
return err
}
routesWithPrefix := make([]*route.Route, 0)
for _, r := range routes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routesWithPrefix = append(routesWithPrefix, r)
}
}
//routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
// lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool)
@@ -51,8 +65,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true
group := account.GetGroup(groupID)
if group == nil {
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
if err != nil || group == nil {
return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID,
@@ -67,10 +81,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID)
if peer == nil {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -79,7 +94,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
// we validated the group existence before entering this function, no need to check again.
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
if err != nil || group == nil {
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
}
if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf(
@@ -90,10 +109,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id)
if peer == nil {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -151,10 +171,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.ID = route.ID(xid.New().String())
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
if err != nil {
return nil, err
}
//err = validateGroups(peerGroupIDs, account.Groups)
//if err != nil {
// return nil, err
//}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
@@ -170,10 +190,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
//err = validateGroups(groups, account.Groups)
//if err != nil {
// return nil, err
//}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
@@ -208,13 +228,19 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
// SaveRoute saves route
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil")
}
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "user not allowed to update route")
}
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
@@ -223,16 +249,14 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
account, err := am.Store.GetAccount(ctx, accountID)
// Do not allow non-Linux peers
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, routeToSave.Peer, accountID)
if err != nil {
return err
}
// Do not allow non-Linux peers
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
if peer.Meta.GoOS != "linux" {
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
@@ -251,60 +275,78 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
}
if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups)
if err != nil {
return err
}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
groups, err := am.Store.GetAccountGroups(ctx, accountID)
if err != nil {
return err
}
_ = groups
err = validateGroups(routeToSave.Groups, account.Groups)
if err != nil {
return err
}
//if len(routeToSave.PeerGroups) > 0 {
// err = validateGroups(routeToSave.PeerGroups, groups)
// if err != nil {
// return err
// }
//}
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.updateAccountPeers(ctx, account)
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
//err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
//if err != nil {
// return err
//}
//
//err = validateGroups(routeToSave.Groups, account.Groups)
//if err != nil {
// return err
//}
//
//account.Routes[routeToSave.ID] = routeToSave
//
//account.Network.IncSerial()
//if err = am.Store.SaveAccount(ctx, account); err != nil {
// return err
//}
//
//am.updateAccountPeers(ctx, account)
//
//am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
if user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "user not allowed to delete route")
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, string(routeID), accountID)
if err != nil {
return fmt.Errorf("failed to delete route: %w", err)
}
return nil
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
return nil

View File

@@ -361,21 +361,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
return nil
}
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
for id, pat := range user.PATs {
pat.ID = id
user.PATsG = append(user.PATsG, *pat)
}
usersToSave = append(usersToSave, *user)
// SaveUser saves a user to the store.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
return saveRecord[User](s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}), lockStrength, user)
}
// SaveUsers saves a list of users to the store.
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save users to store: %v", result.Error)
}
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
return nil
}
// DeleteUser deletes a user from the store.
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, userID, accountID string) error {
return deleteRecordByID[User](s.db.WithContext(ctx).Select(clause.Associations), lockStrength, userID, accountID)
}
// DeleteUsers deletes a list of users from the store.
func (s *SqlStore) DeleteUsers(ctx context.Context, strength LockingStrength, userIDs []string, accountID string) error {
result := s.db.WithContext(ctx).Select(clause.Associations).Clauses(clause.Locking{Strength: string(strength)}).
Where("id IN ? AND account_id = ?", userIDs, accountID).Delete(&User{})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to delete users from store: %v", result.Error)
}
return nil
}
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
@@ -695,6 +708,22 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
return accountID, nil
}
// GetAccountOwnerID returns the owner ID of the account.
func (s *SqlStore) GetAccountOwnerID(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
var ownerID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("created_by").Where(idQueryCondition, accountID).First(&ownerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found")
}
return "", status.Errorf(status.Internal, "failed to get account owner from store: %v", result.Error)
}
return ownerID, nil
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
var ipJSONStrings []string
@@ -1151,7 +1180,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
Where("account_id AND id IN ?", accountID, groupIDs).Delete(&nbgroup.Group{})
Where("id IN ? AND account_id = ?", groupIDs, accountID).Delete(&nbgroup.Group{})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
}
@@ -1276,7 +1305,7 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "user_id = ? and id = ?", userID, patID)
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "PAT not found")
@@ -1295,7 +1324,7 @@ func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pa
// DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, patID, userID string) error {
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&PersonalAccessToken{}, "user_id = ? and id = ?", userID, patID).Error
Delete(&PersonalAccessToken{}, "id = ? AND user_id = ?", patID, userID).Error
}
// GetAccountPeers retrieves peers for an account.
@@ -1303,6 +1332,11 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
return getRecords[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
}
// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true)

View File

@@ -51,6 +51,7 @@ type Store interface {
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountOwnerID(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
@@ -63,7 +64,10 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, userID, accountID string) error
DeleteUsers(ctx context.Context, strength LockingStrength, userIDs []string, accountID string) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
@@ -98,6 +102,7 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, peerID string, accountID string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error

View File

@@ -8,6 +8,7 @@ import (
"time"
"github.com/google/uuid"
nbgroup "github.com/netbirdio/netbird/management/server/group"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
@@ -213,19 +214,12 @@ func NewOwnerUser(id string) *User {
// createServiceUser creates a new service user under the given account.
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
return nil, err
}
executingUser := account.Users[initiatorUserID]
if executingUser == nil {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !executingUser.HasAdminPower() {
if !executingUser.HasAdminPower() || executingUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users")
}
@@ -236,10 +230,9 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
newUserID := uuid.New().String()
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
log.WithContext(ctx).Debugf("New User: %v", newUser)
account.Users[newUserID] = newUser
newUser.AccountID = accountID
err = am.Store.SaveAccount(ctx, account)
if err != nil {
if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil {
return nil, err
}
@@ -267,11 +260,8 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
return am.inviteNewUser(ctx, accountID, userID, user)
}
// inviteNewUser Invites a USer to a given account and creates reference in datastore
// inviteNewUser Invites a User to a given account and creates reference in datastore
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
}
@@ -292,23 +282,24 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
default:
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
initiatorUser, err := account.FindUser(userID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, status.Errorf(status.NotFound, "initiator user with ID %s doesn't exist", userID)
}
inviterID := userID
if initiatorUser.IsServiceUser {
inviterID = account.CreatedBy
ownerID, err := am.Store.GetAccountOwnerID(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account owner: %v", err)
return nil, err
}
inviterID = ownerID
}
// inviterUser is the one who is inviting the new user
inviterUser, err := am.lookupUserInCache(ctx, inviterID, account)
inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID)
if err != nil || inviterUser == nil {
return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID)
}
@@ -339,27 +330,29 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
newUser := &User{
Id: idpUser.ID,
AccountID: accountID,
Role: invitedRole,
AutoGroups: invite.AutoGroups,
Issued: invite.Issued,
IntegrationReference: invite.IntegrationReference,
CreatedAt: time.Now().UTC(),
}
account.Users[idpUser.ID] = newUser
err = am.Store.SaveAccount(ctx, account)
if err != nil {
if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil {
return nil, err
}
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
return nil, err
}
am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil)
return newUser.ToUserInfo(idpUser, account.Settings)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
return newUser.ToUserInfo(idpUser, settings)
}
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
@@ -399,20 +392,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
// ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
users := make([]*User, 0, len(account.Users))
for _, item := range account.Users {
users = append(users, item)
}
return users, nil
return am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
}
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) {
@@ -503,20 +483,12 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if am.idpManager == nil {
return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
}
// check if the user is already registered with this ID
user, err := am.lookupUserInCache(ctx, targetUserID, account)
user, err := am.lookupUserInCache(ctx, targetUserID, accountID)
if err != nil {
return err
}
@@ -663,9 +635,6 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
if err != nil {
return nil, err
@@ -686,17 +655,12 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
return nil, nil //nolint:nilnil
}
account, err := am.Store.GetAccount(ctx, accountID)
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
initiatorUser, err := account.FindUser(initiatorUserID)
if err != nil {
return nil, err
}
if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() {
if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() || initiatorUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
}
@@ -704,15 +668,21 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var (
expiredPeers []*nbpeer.Peer
eventsToStore []func()
usersToSave []*User
)
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
for _, update := range updates {
if update == nil {
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
}
oldUser := account.Users[update.Id]
if oldUser == nil {
oldUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, update.Id)
if err != nil {
if !addIfNotExists {
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
}
@@ -720,7 +690,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
oldUser = update
}
if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
if err := am.validateUserUpdate(ctx, accountID, initiatorUser, oldUser, update); err != nil {
return nil, err
}
@@ -733,29 +703,40 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
newUser.Issued = update.Issued
newUser.IntegrationReference = update.IntegrationReference
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
account.Users[newUser.Id] = newUser
// handle owner role transfer
transferredOwnerRole := initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner
if transferredOwnerRole {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
usersToSave = append(usersToSave, newInitiatorUser)
}
usersToSave = append(usersToSave, newUser)
if !oldUser.IsBlocked() && update.IsBlocked() {
// expire peers that belong to the user who's getting blocked
blockedPeers, err := account.FindUserPeers(update.Id)
blockedPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, update.Id, accountID)
if err != nil {
return nil, err
}
expiredPeers = append(expiredPeers, blockedPeers...)
}
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
//removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
//TODO: wraps this in a transaction
//account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
//account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
}
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, accountID, transferredOwnerRole)
eventsToStore = append(eventsToStore, events...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, accountID)
if err != nil {
return nil, err
}
@@ -763,40 +744,63 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
if len(expiredPeers) > 0 {
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
return nil, err
}
}
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
if account.Settings.GroupsPropagationEnabled {
am.updateAccountPeers(ctx, account)
//TODO: update groups with new members
if err = transaction.SaveUsers(ctx, LockingStrengthUpdate, usersToSave); err != nil {
return fmt.Errorf("failed to save users: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if settings.GroupsPropagationEnabled {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
}
return updatedUsers, nil
}
// propagateAutoGroupChangesForUser updates the user's auto-groups.
// If group propagation is enabled, it adds or removes groups from
// the peers owned by the user based on changes in their group assignments.
func (am *DefaultAccountManager) propagateAutoGroupChangesForUser(ctx context.Context, oldUser, updatedUser *User) []*nbgroup.Group {
return nil
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, accountID string, transferredOwnerRole bool) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
if newUser.IsBlocked() {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil)
})
} else {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil)
})
}
}
@@ -804,11 +808,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
switch {
case transferredOwnerRole:
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil)
})
case oldUser.Role != newUser.Role:
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
})
}
@@ -816,23 +820,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
for _, g := range removedGroups {
group := account.GetGroup(g)
if group != nil {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
})
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": newUser.IsServiceUser,
"user_name": newUser.ServiceUserName,
}
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
})
}
}
for _, g := range addedGroups {
group := account.GetGroup(g)
if group != nil {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
if err != nil {
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
meta := map[string]any{
"group": group.Name,
"group_id": group.ID,
"is_service_user": newUser.IsServiceUser,
"user_name": newUser.ServiceUserName,
}
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
})
}
}
@@ -841,32 +857,27 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
return eventsToStore
}
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = UserRoleAdmin
account.Users[initiatorUser.Id] = newInitiatorUser
return true
}
return false
}
// getUserInfo retrieves the UserInfo for a given User and Account.
// If the AccountManager has a non-nil idpManager and the User is not a service user,
// it will attempt to look up the UserData from the cache.
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, accountID string) (*UserInfo, error) {
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if !isNil(am.idpManager) && !user.IsServiceUser {
userData, err := am.lookupUserInCache(ctx, user.Id, account)
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
if err != nil {
return nil, err
}
return user.ToUserInfo(userData, account.Settings)
return user.ToUserInfo(userData, settings)
}
return user.ToUserInfo(nil, account.Settings)
return user.ToUserInfo(nil, settings)
}
// validateUserUpdate validates the update operation for a user.
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
func (am *DefaultAccountManager) validateUserUpdate(ctx context.Context, accountID string, initiatorUser, oldUser, update *User) error {
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
}
@@ -887,11 +898,12 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
}
for _, newGroupID := range update.AutoGroups {
group, ok := account.Groups[newGroupID]
if !ok {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroupID, accountID)
if err != nil {
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
newGroupID, update.Id)
}
if group.Name == "All" {
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
}
@@ -942,21 +954,26 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
// based on provided user role.
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) {
account, err := am.Store.GetAccount(ctx, accountID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get users")
}
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
queriedUsers := make([]*idp.UserData, 0)
if !isNil(am.idpManager) {
users := make(map[string]userLoggedInOnce, len(account.Users))
users := make(map[string]userLoggedInOnce, len(accountUsers))
usersFromIntegration := make([]*idp.UserData, 0)
for _, user := range account.Users {
for _, user := range accountUsers {
if user.Issued == UserIssuedIntegration {
key := user.IntegrationReference.CacheKey(accountID, user.Id)
info, err := am.externalCacheManager.Get(am.ctx, key)
@@ -981,16 +998,21 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
queriedUsers = append(queriedUsers, usersFromIntegration...)
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
userInfos := make([]*UserInfo, 0)
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
if len(queriedUsers) == 0 {
for _, accountUser := range account.Users {
for _, accountUser := range accountUsers {
if !(user.HasAdminPower() || user.IsServiceUser || user.Id == accountUser.Id) {
// if user is not an admin then show only current user and do not show other users
continue
}
info, err := accountUser.ToUserInfo(nil, account.Settings)
info, err := accountUser.ToUserInfo(nil, settings)
if err != nil {
return nil, err
}
@@ -999,7 +1021,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
return userInfos, nil
}
for _, localUser := range account.Users {
for _, localUser := range accountUsers {
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != localUser.Id {
// if user is not an admin then show only current user and do not show other users
continue
@@ -1007,7 +1029,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var info *UserInfo
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
info, err = localUser.ToUserInfo(queriedUser, account.Settings)
info, err = localUser.ToUserInfo(queriedUser, settings)
if err != nil {
return nil, err
}
@@ -1020,7 +1042,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
dashboardViewPermissions := "full"
if !localUser.HasAdminPower() {
dashboardViewPermissions = "limited"
if account.Settings.RegularUsersViewBlocked {
if settings.RegularUsersViewBlocked {
dashboardViewPermissions = "blocked"
}
}
@@ -1044,7 +1066,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
}
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
var peerIDs []string
for _, peer := range peers {
if peer.Status.LoginExpired {
@@ -1052,13 +1074,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
}
peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
if err := am.Store.SavePeerStatus(accountID, peer.ID, *peer.Status); err != nil {
return err
}
am.StoreEvent(
ctx,
peer.UserID, peer.ID, account.Id,
peer.UserID, peer.ID, accountID,
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
)
}
@@ -1066,6 +1088,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
}
return nil