mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
@@ -1125,6 +1125,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
// TODO: call direct on the store to get expired peers
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID)
|
||||
@@ -1139,7 +1140,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
||||
|
||||
log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||
|
||||
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||
return account.GetNextPeerExpiration()
|
||||
}
|
||||
@@ -1296,7 +1297,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserOrAccountID(ctx context.Conte
|
||||
return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID)
|
||||
}
|
||||
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account); err != nil {
|
||||
if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -1311,28 +1312,28 @@ func isNil(i idp.Manager) bool {
|
||||
}
|
||||
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, account *Account) error {
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
|
||||
// user can be nil if it wasn't found (e.g., just created)
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user != nil && user.AppMetadata.WTAccountID == account.Id {
|
||||
if user != nil && user.AppMetadata.WTAccountID == accountID {
|
||||
// it was already set, so we skip the unnecessary update
|
||||
log.WithContext(ctx).Debugf("skipping IDP App Meta update because accountID %s has been already set for user %s",
|
||||
account.Id, userID)
|
||||
accountID, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
err = am.idpManager.UpdateUserAppMetadata(ctx, userID, idp.AppMetadata{WTAccountID: accountID})
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "updating user's app metadata failed with: %v", err)
|
||||
}
|
||||
// refresh cache to reflect the update
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1391,10 +1392,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
|
||||
}
|
||||
|
||||
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) {
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
// ignore service users and users provisioned by integrations than are never logged in
|
||||
for _, user := range account.Users {
|
||||
for _, user := range accountUsers {
|
||||
if user.IsServiceUser {
|
||||
continue
|
||||
}
|
||||
@@ -1403,8 +1409,9 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
}
|
||||
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
|
||||
}
|
||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
|
||||
userData, err := am.lookupCache(ctx, users, account.Id)
|
||||
|
||||
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
|
||||
userData, err := am.lookupCache(ctx, users, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1417,13 +1424,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
|
||||
|
||||
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
|
||||
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
|
||||
user, err := account.FindUser(userID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key := user.IntegrationReference.CacheKey(account.Id, userID)
|
||||
key := user.IntegrationReference.CacheKey(accountID, userID)
|
||||
ud, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
|
||||
@@ -1591,13 +1598,14 @@ func (am *DefaultAccountManager) handleExistingUserAccount(
|
||||
primaryDomain bool,
|
||||
claims jwtclaims.AuthorizationClaims,
|
||||
) error {
|
||||
// TODO: remove account as parameter and pass accountID string
|
||||
err := am.updateAccountDomainAttributes(ctx, existingAcc, claims, primaryDomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// we should register the account ID to this user's metadata in our IDP manager
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, existingAcc.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1635,7 +1643,7 @@ func (am *DefaultAccountManager) handleNewUserAccount(ctx context.Context, domai
|
||||
}
|
||||
}
|
||||
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account)
|
||||
err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, account.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1653,12 +1661,12 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
_, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := am.lookupUserInCache(ctx, userID, account)
|
||||
user, err := am.lookupUserInCache(ctx, userID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1668,17 +1676,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
}
|
||||
|
||||
if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
|
||||
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
|
||||
// Our job is to just reload cache.
|
||||
go func() {
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
|
||||
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
|
||||
return
|
||||
}
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
|
||||
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
|
||||
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
|
||||
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
|
||||
}()
|
||||
}
|
||||
|
||||
|
||||
@@ -1704,7 +1704,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
|
||||
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -1821,7 +1821,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
|
||||
},
|
||||
}
|
||||
// enabling PeerLoginExpirationEnabled should trigger the expiration job
|
||||
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
})
|
||||
@@ -1852,13 +1852,13 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
|
||||
accountID, err := manager.GetAccountIDByUserOrAccountID(context.Background(), userID, "", "")
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
assert.False(t, updated.Settings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour)
|
||||
assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
|
||||
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
|
||||
|
||||
accountID, err = manager.GetAccountIDByUserOrAccountID(context.Background(), "", accountID, "")
|
||||
require.NoError(t, err, "unable to get account by ID")
|
||||
|
||||
@@ -959,10 +959,6 @@ func (s *FileStore) GetStoreEngine() StoreEngine {
|
||||
return FileStoreEngine
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveUsers(_ string, _ map[string]*User) error {
|
||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveGroups(_ context.Context, _ LockingStrength, _ []*nbgroup.Group) error {
|
||||
return status.Errorf(status.Internal, "SaveGroups is not implemented")
|
||||
}
|
||||
@@ -1085,6 +1081,10 @@ func (s *FileStore) GetAccountPeers(_ context.Context, _ LockingStrength, _ stri
|
||||
return nil, status.Errorf(status.Internal, "GetAccountPeers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetUserPeers(_ context.Context, _ LockingStrength, _, _ string) ([]*nbpeer.Peer, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetUserPeers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountPeersWithExpiration(_ context.Context, _ LockingStrength, _ string) ([]*nbpeer.Peer, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountPeersWithExpiration is not implemented")
|
||||
}
|
||||
@@ -1127,3 +1127,23 @@ func (s *FileStore) DeleteGroups(_ context.Context, _ LockingStrength, _ []strin
|
||||
func (s *FileStore) GetAccountUsers(_ context.Context, _ LockingStrength, _ string) ([]*User, error) {
|
||||
return nil, status.Errorf(status.Internal, "GetAccountUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveUser(_ context.Context, _ LockingStrength, _ *User) error {
|
||||
return status.Errorf(status.Internal, "SaveUser is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) SaveUsers(_ context.Context, _ LockingStrength, _ []*User) error {
|
||||
return status.Errorf(status.Internal, "SaveUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) DeleteUser(_ context.Context, _ LockingStrength, _, _ string) error {
|
||||
return status.Errorf(status.Internal, "DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) DeleteUsers(_ context.Context, _ LockingStrength, _ []string, _ string) error {
|
||||
return status.Errorf(status.Internal, "DeleteUsers is not implemented")
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountOwnerID(_ context.Context, _ LockingStrength, _ string) (string, error) {
|
||||
return "", status.Errorf(status.Internal, "GetAccountOwnerID is not implemented")
|
||||
}
|
||||
|
||||
@@ -32,7 +32,21 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
routes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routesWithPrefix := make([]*route.Route, 0)
|
||||
for _, r := range routes {
|
||||
dynamic := r.IsDynamic()
|
||||
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||
!dynamic && r.Network.String() == prefix.String() {
|
||||
routesWithPrefix = append(routesWithPrefix, r)
|
||||
}
|
||||
}
|
||||
|
||||
//routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@@ -51,8 +65,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
for _, groupID := range prefixRoute.PeerGroups {
|
||||
seenPeerGroups[groupID] = true
|
||||
|
||||
group := account.GetGroup(groupID)
|
||||
if group == nil {
|
||||
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
|
||||
if err != nil || group == nil {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||
getRouteDescriptor(prefix, domains), groupID,
|
||||
@@ -67,10 +81,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
|
||||
if peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
|
||||
if err != nil || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeers[peerID]; ok {
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||
@@ -79,7 +94,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
|
||||
// check that peerGroupIDs are not in any route peerGroups list
|
||||
for _, groupID := range peerGroupIDs {
|
||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
||||
// we validated the group existence before entering this function, no need to check again.
|
||||
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
|
||||
if err != nil || group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeerGroups[groupID]; ok {
|
||||
return status.Errorf(
|
||||
@@ -90,10 +109,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
for _, id := range group.Peers {
|
||||
if _, ok := seenPeers[id]; ok {
|
||||
peer := account.GetPeer(id)
|
||||
if peer == nil {
|
||||
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
|
||||
if err != nil || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||
@@ -151,10 +171,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//err = validateGroups(peerGroupIDs, account.Groups)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
@@ -170,10 +190,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
err = validateGroups(groups, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//err = validateGroups(groups, account.Groups)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
newRoute.Peer = peerID
|
||||
newRoute.PeerGroups = peerGroupIDs
|
||||
@@ -208,13 +228,19 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
|
||||
// SaveRoute saves route
|
||||
func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userID string, routeToSave *route.Route) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if routeToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
}
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "user not allowed to update route")
|
||||
}
|
||||
|
||||
if routeToSave.Metric < route.MinMetric || routeToSave.Metric > route.MaxMetric {
|
||||
return status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
@@ -223,16 +249,14 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
// Do not allow non-Linux peers
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, routeToSave.Peer, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow non-Linux peers
|
||||
if peer := account.GetPeer(routeToSave.Peer); peer != nil {
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
if peer.Meta.GoOS != "linux" {
|
||||
return status.Errorf(status.InvalidArgument, "non-linux peers are not supported as network routes")
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
@@ -251,60 +275,78 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||
}
|
||||
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_ = groups
|
||||
|
||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//if len(routeToSave.PeerGroups) > 0 {
|
||||
// err = validateGroups(routeToSave.PeerGroups, groups)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
//}
|
||||
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
//err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//account.Routes[routeToSave.ID] = routeToSave
|
||||
//
|
||||
//account.Network.IncSerial()
|
||||
//if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
// return err
|
||||
//}
|
||||
//
|
||||
//am.updateAccountPeers(ctx, account)
|
||||
//
|
||||
//am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routy := account.Routes[routeID]
|
||||
if routy == nil {
|
||||
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
|
||||
if user.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "user not allowed to delete route")
|
||||
}
|
||||
delete(account.Routes, routeID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
route, err := am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.DeleteRoute(ctx, LockingStrengthUpdate, string(routeID), accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete route: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -361,21 +361,34 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveUsers saves the given list of users to the database.
|
||||
// It updates existing users if a conflict occurs.
|
||||
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
|
||||
usersToSave := make([]User, 0, len(users))
|
||||
for _, user := range users {
|
||||
user.AccountID = accountID
|
||||
for id, pat := range user.PATs {
|
||||
pat.ID = id
|
||||
user.PATsG = append(user.PATsG, *pat)
|
||||
}
|
||||
usersToSave = append(usersToSave, *user)
|
||||
// SaveUser saves a user to the store.
|
||||
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
|
||||
return saveRecord[User](s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}), lockStrength, user)
|
||||
}
|
||||
|
||||
// SaveUsers saves a list of users to the store.
|
||||
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error {
|
||||
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&users)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save users to store: %v", result.Error)
|
||||
}
|
||||
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.OnConflict{UpdateAll: true}).
|
||||
Create(&usersToSave).Error
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user from the store.
|
||||
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, userID, accountID string) error {
|
||||
return deleteRecordByID[User](s.db.WithContext(ctx).Select(clause.Associations), lockStrength, userID, accountID)
|
||||
}
|
||||
|
||||
// DeleteUsers deletes a list of users from the store.
|
||||
func (s *SqlStore) DeleteUsers(ctx context.Context, strength LockingStrength, userIDs []string, accountID string) error {
|
||||
result := s.db.WithContext(ctx).Select(clause.Associations).Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Where("id IN ? AND account_id = ?", userIDs, accountID).Delete(&User{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete users from store: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteHashedPAT2TokenIDIndex is noop in SqlStore
|
||||
@@ -695,6 +708,22 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// GetAccountOwnerID returns the owner ID of the account.
|
||||
func (s *SqlStore) GetAccountOwnerID(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
|
||||
var ownerID string
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("created_by").Where(idQueryCondition, accountID).First(&ownerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
return "", status.Errorf(status.Internal, "failed to get account owner from store: %v", result.Error)
|
||||
}
|
||||
|
||||
return ownerID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
|
||||
var ipJSONStrings []string
|
||||
|
||||
@@ -1151,7 +1180,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, groupIDs []string, accountID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Where("account_id AND id IN ?", accountID, groupIDs).Delete(&nbgroup.Group{})
|
||||
Where("id IN ? AND account_id = ?", groupIDs, accountID).Delete(&nbgroup.Group{})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
|
||||
}
|
||||
@@ -1276,7 +1305,7 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
|
||||
var pat PersonalAccessToken
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&pat, "user_id = ? and id = ?", userID, patID)
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
@@ -1295,7 +1324,7 @@ func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pa
|
||||
// DeletePAT deletes a personal access token from the database.
|
||||
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, patID, userID string) error {
|
||||
return s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&PersonalAccessToken{}, "user_id = ? and id = ?", userID, patID).Error
|
||||
Delete(&PersonalAccessToken{}, "id = ? AND user_id = ?", patID, userID).Error
|
||||
}
|
||||
|
||||
// GetAccountPeers retrieves peers for an account.
|
||||
@@ -1303,6 +1332,11 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
|
||||
return getRecords[nbpeer.Peer](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountPeersWithExpiration retrieves a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
db := s.db.WithContext(ctx).Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true)
|
||||
|
||||
@@ -51,6 +51,7 @@ type Store interface {
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountOwnerID(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
|
||||
@@ -63,7 +64,10 @@ type Store interface {
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*User) error
|
||||
DeleteUser(ctx context.Context, lockStrength LockingStrength, userID, accountID string) error
|
||||
DeleteUsers(ctx context.Context, strength LockingStrength, userIDs []string, accountID string) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
DeleteHashedPAT2TokenIDIndex(hashedToken string) error
|
||||
@@ -98,6 +102,7 @@ type Store interface {
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, peerID string, accountID string) (*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -213,19 +214,12 @@ func NewOwnerUser(id string) *User {
|
||||
|
||||
// createServiceUser creates a new service user under the given account.
|
||||
func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if !executingUser.HasAdminPower() {
|
||||
if !executingUser.HasAdminPower() || executingUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users")
|
||||
}
|
||||
|
||||
@@ -236,10 +230,9 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
|
||||
newUserID := uuid.New().String()
|
||||
newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI)
|
||||
log.WithContext(ctx).Debugf("New User: %v", newUser)
|
||||
account.Users[newUserID] = newUser
|
||||
newUser.AccountID = accountID
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -267,11 +260,8 @@ func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, user
|
||||
return am.inviteNewUser(ctx, accountID, userID, user)
|
||||
}
|
||||
|
||||
// inviteNewUser Invites a USer to a given account and creates reference in datastore
|
||||
// inviteNewUser Invites a User to a given account and creates reference in datastore
|
||||
func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return nil, status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
@@ -292,23 +282,24 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
default:
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
initiatorUser, err := account.FindUser(userID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.NotFound, "initiator user with ID %s doesn't exist", userID)
|
||||
}
|
||||
|
||||
inviterID := userID
|
||||
if initiatorUser.IsServiceUser {
|
||||
inviterID = account.CreatedBy
|
||||
ownerID, err := am.Store.GetAccountOwnerID(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account owner: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inviterID = ownerID
|
||||
}
|
||||
|
||||
// inviterUser is the one who is inviting the new user
|
||||
inviterUser, err := am.lookupUserInCache(ctx, inviterID, account)
|
||||
inviterUser, err := am.lookupUserInCache(ctx, inviterID, accountID)
|
||||
if err != nil || inviterUser == nil {
|
||||
return nil, status.Errorf(status.NotFound, "inviter user with ID %s doesn't exist in IdP", inviterID)
|
||||
}
|
||||
@@ -339,27 +330,29 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
|
||||
|
||||
newUser := &User{
|
||||
Id: idpUser.ID,
|
||||
AccountID: accountID,
|
||||
Role: invitedRole,
|
||||
AutoGroups: invite.AutoGroups,
|
||||
Issued: invite.Issued,
|
||||
IntegrationReference: invite.IntegrationReference,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
account.Users[idpUser.ID] = newUser
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
if err = am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = am.refreshCache(ctx, account.Id)
|
||||
_, err = am.refreshCache(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, newUser.Id, accountID, activity.UserInvited, nil)
|
||||
|
||||
return newUser.ToUserInfo(idpUser, account.Settings)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newUser.ToUserInfo(idpUser, settings)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||
@@ -399,20 +392,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A
|
||||
// ListUsers returns lists of all users under the account.
|
||||
// It doesn't populate user information such as email or name.
|
||||
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]*User, 0, len(account.Users))
|
||||
for _, item := range account.Users {
|
||||
users = append(users, item)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
return am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) {
|
||||
@@ -503,20 +483,12 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
||||
|
||||
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
|
||||
func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if am.idpManager == nil {
|
||||
return status.Errorf(status.PreconditionFailed, "IdP manager must be enabled to send user invites")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "account %s doesn't exist", accountID)
|
||||
}
|
||||
|
||||
// check if the user is already registered with this ID
|
||||
user, err := am.lookupUserInCache(ctx, targetUserID, account)
|
||||
user, err := am.lookupUserInCache(ctx, targetUserID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -663,9 +635,6 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -686,17 +655,12 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
initiatorUser, err := account.FindUser(initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() {
|
||||
if !initiatorUser.HasAdminPower() || initiatorUser.IsBlocked() || initiatorUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations")
|
||||
}
|
||||
|
||||
@@ -704,15 +668,21 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
var (
|
||||
expiredPeers []*nbpeer.Peer
|
||||
eventsToStore []func()
|
||||
usersToSave []*User
|
||||
)
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, update := range updates {
|
||||
if update == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "provided user update is nil")
|
||||
}
|
||||
|
||||
oldUser := account.Users[update.Id]
|
||||
if oldUser == nil {
|
||||
oldUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, update.Id)
|
||||
if err != nil {
|
||||
if !addIfNotExists {
|
||||
return nil, status.Errorf(status.NotFound, "user to update doesn't exist: %s", update.Id)
|
||||
}
|
||||
@@ -720,7 +690,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
oldUser = update
|
||||
}
|
||||
|
||||
if err := validateUserUpdate(account, initiatorUser, oldUser, update); err != nil {
|
||||
if err := am.validateUserUpdate(ctx, accountID, initiatorUser, oldUser, update); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -733,29 +703,40 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
newUser.Issued = update.Issued
|
||||
newUser.IntegrationReference = update.IntegrationReference
|
||||
|
||||
transferredOwnerRole := handleOwnerRoleTransfer(account, initiatorUser, update)
|
||||
account.Users[newUser.Id] = newUser
|
||||
// handle owner role transfer
|
||||
transferredOwnerRole := initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner
|
||||
if transferredOwnerRole {
|
||||
newInitiatorUser := initiatorUser.Copy()
|
||||
newInitiatorUser.Role = UserRoleAdmin
|
||||
|
||||
usersToSave = append(usersToSave, newInitiatorUser)
|
||||
}
|
||||
usersToSave = append(usersToSave, newUser)
|
||||
|
||||
if !oldUser.IsBlocked() && update.IsBlocked() {
|
||||
// expire peers that belong to the user who's getting blocked
|
||||
blockedPeers, err := account.FindUserPeers(update.Id)
|
||||
blockedPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, update.Id, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
expiredPeers = append(expiredPeers, blockedPeers...)
|
||||
}
|
||||
|
||||
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
|
||||
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
|
||||
//removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
|
||||
// need force update all auto groups in any case they will not be duplicated
|
||||
account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
|
||||
//TODO: wraps this in a transaction
|
||||
|
||||
//account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
|
||||
//account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
|
||||
|
||||
}
|
||||
|
||||
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
|
||||
events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, accountID, transferredOwnerRole)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
|
||||
updatedUserInfo, err := getUserInfo(ctx, am, newUser, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -763,40 +744,63 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if len(expiredPeers) > 0 {
|
||||
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
|
||||
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if account.Settings.GroupsPropagationEnabled {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
//TODO: update groups with new members
|
||||
|
||||
if err = transaction.SaveUsers(ctx, LockingStrengthUpdate, usersToSave); err != nil {
|
||||
return fmt.Errorf("failed to save users: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if settings.GroupsPropagationEnabled {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
return updatedUsers, nil
|
||||
}
|
||||
|
||||
// propagateAutoGroupChangesForUser updates the user's auto-groups.
|
||||
// If group propagation is enabled, it adds or removes groups from
|
||||
// the peers owned by the user based on changes in their group assignments.
|
||||
func (am *DefaultAccountManager) propagateAutoGroupChangesForUser(ctx context.Context, oldUser, updatedUser *User) []*nbgroup.Group {
|
||||
return nil
|
||||
}
|
||||
|
||||
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() {
|
||||
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, accountID string, transferredOwnerRole bool) []func() {
|
||||
var eventsToStore []func()
|
||||
|
||||
if oldUser.IsBlocked() != newUser.IsBlocked() {
|
||||
if newUser.IsBlocked() {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserBlocked, nil)
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserBlocked, nil)
|
||||
})
|
||||
} else {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserUnblocked, nil)
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserUnblocked, nil)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -804,11 +808,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
switch {
|
||||
case transferredOwnerRole:
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.TransferredOwnerRole, nil)
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.TransferredOwnerRole, nil)
|
||||
})
|
||||
case oldUser.Role != newUser.Role:
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.UserRoleUpdated, map[string]any{"role": newUser.Role})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -816,23 +820,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
|
||||
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
|
||||
for _, g := range removedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
})
|
||||
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": newUser.IsServiceUser,
|
||||
"user_name": newUser.ServiceUserName,
|
||||
}
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
for _, g := range addedGroups {
|
||||
group := account.GetGroup(g)
|
||||
if group != nil {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
|
||||
map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
|
||||
meta := map[string]any{
|
||||
"group": group.Name,
|
||||
"group_id": group.ID,
|
||||
"is_service_user": newUser.IsServiceUser,
|
||||
"user_name": newUser.ServiceUserName,
|
||||
}
|
||||
am.StoreEvent(ctx, initiatorUserID, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -841,32 +857,27 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
|
||||
return eventsToStore
|
||||
}
|
||||
|
||||
func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool {
|
||||
if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner {
|
||||
newInitiatorUser := initiatorUser.Copy()
|
||||
newInitiatorUser.Role = UserRoleAdmin
|
||||
account.Users[initiatorUser.Id] = newInitiatorUser
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getUserInfo retrieves the UserInfo for a given User and Account.
|
||||
// If the AccountManager has a non-nil idpManager and the User is not a service user,
|
||||
// it will attempt to look up the UserData from the cache.
|
||||
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) {
|
||||
func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, accountID string) (*UserInfo, error) {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) && !user.IsServiceUser {
|
||||
userData, err := am.lookupUserInCache(ctx, user.Id, account)
|
||||
userData, err := am.lookupUserInCache(ctx, user.Id, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return user.ToUserInfo(userData, account.Settings)
|
||||
return user.ToUserInfo(userData, settings)
|
||||
}
|
||||
return user.ToUserInfo(nil, account.Settings)
|
||||
return user.ToUserInfo(nil, settings)
|
||||
}
|
||||
|
||||
// validateUserUpdate validates the update operation for a user.
|
||||
func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error {
|
||||
func (am *DefaultAccountManager) validateUserUpdate(ctx context.Context, accountID string, initiatorUser, oldUser, update *User) error {
|
||||
if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked {
|
||||
return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves")
|
||||
}
|
||||
@@ -887,11 +898,12 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
|
||||
}
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
group, ok := account.Groups[newGroupID]
|
||||
if !ok {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroupID, accountID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
|
||||
if group.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
|
||||
}
|
||||
@@ -942,21 +954,26 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u
|
||||
// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return
|
||||
// based on provided user role.
|
||||
func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get users")
|
||||
}
|
||||
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queriedUsers := make([]*idp.UserData, 0)
|
||||
if !isNil(am.idpManager) {
|
||||
users := make(map[string]userLoggedInOnce, len(account.Users))
|
||||
users := make(map[string]userLoggedInOnce, len(accountUsers))
|
||||
|
||||
usersFromIntegration := make([]*idp.UserData, 0)
|
||||
for _, user := range account.Users {
|
||||
for _, user := range accountUsers {
|
||||
if user.Issued == UserIssuedIntegration {
|
||||
key := user.IntegrationReference.CacheKey(accountID, user.Id)
|
||||
info, err := am.externalCacheManager.Get(am.ctx, key)
|
||||
@@ -981,16 +998,21 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
queriedUsers = append(queriedUsers, usersFromIntegration...)
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userInfos := make([]*UserInfo, 0)
|
||||
|
||||
// in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo
|
||||
if len(queriedUsers) == 0 {
|
||||
for _, accountUser := range account.Users {
|
||||
for _, accountUser := range accountUsers {
|
||||
if !(user.HasAdminPower() || user.IsServiceUser || user.Id == accountUser.Id) {
|
||||
// if user is not an admin then show only current user and do not show other users
|
||||
continue
|
||||
}
|
||||
info, err := accountUser.ToUserInfo(nil, account.Settings)
|
||||
info, err := accountUser.ToUserInfo(nil, settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -999,7 +1021,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
for _, localUser := range account.Users {
|
||||
for _, localUser := range accountUsers {
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != localUser.Id {
|
||||
// if user is not an admin then show only current user and do not show other users
|
||||
continue
|
||||
@@ -1007,7 +1029,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
|
||||
var info *UserInfo
|
||||
if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains {
|
||||
info, err = localUser.ToUserInfo(queriedUser, account.Settings)
|
||||
info, err = localUser.ToUserInfo(queriedUser, settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1020,7 +1042,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
dashboardViewPermissions := "full"
|
||||
if !localUser.HasAdminPower() {
|
||||
dashboardViewPermissions = "limited"
|
||||
if account.Settings.RegularUsersViewBlocked {
|
||||
if settings.RegularUsersViewBlocked {
|
||||
dashboardViewPermissions = "blocked"
|
||||
}
|
||||
}
|
||||
@@ -1044,7 +1066,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error {
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
|
||||
var peerIDs []string
|
||||
for _, peer := range peers {
|
||||
if peer.Status.LoginExpired {
|
||||
@@ -1052,13 +1074,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
}
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
|
||||
|
||||
if err := am.Store.SavePeerStatus(accountID, peer.ID, *peer.Status); err != nil {
|
||||
return err
|
||||
}
|
||||
am.StoreEvent(
|
||||
ctx,
|
||||
peer.UserID, peer.ID, account.Id,
|
||||
peer.UserID, peer.ID, accountID,
|
||||
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
|
||||
)
|
||||
}
|
||||
@@ -1066,6 +1088,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
am.updateAccountPeers(ctx, account)
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user