diff --git a/management/server/account.go b/management/server/account.go index 19ff96cfb..9a1ca9866 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -45,6 +45,7 @@ import ( const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) @@ -469,7 +470,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - return 0, false + return peerSchedulerRetryInterval, true } var peerIDs []string @@ -481,7 +482,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } return am.getNextPeerExpiration(ctx, accountID) @@ -504,7 +505,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } var peerIDs []string @@ -516,7 +517,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { log.Errorf("failed updating account peers while expiring peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } return am.getNextInactivePeerExpiration(ctx, accountID) diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index f5abb212e..02b669e41 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -13,7 +13,8 @@ import ( ) type Manager interface { - GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) @@ -37,7 +38,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou } } -func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) if err != nil { return nil, err @@ -51,6 +52,15 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string return nil, fmt.Errorf("error getting account groups: %w", err) } + return groups, nil +} + +func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + groups, err := m.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, err + } + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group @@ -130,7 +140,7 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) } -func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { +func ToGroupsInfo(groups []*types.Group, id string) []api.GroupMinimum { groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -167,7 +177,11 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum return groupsInfo } -func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { return map[string]*types.Group{}, nil } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index 6b36a8fce..316b93611 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -82,7 +82,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -267,7 +267,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) } - groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 09d63ea6f..76a0149c6 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -72,13 +72,8 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } dnsDomain := h.accountManager.GetDNSDomain() - groupsMap := map[string]*types.Group{} - grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } - - groupsInfo := groups.ToGroupsInfo(groupsMap, peerID) + grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + groupsInfo := groups.ToGroupsInfo(grps, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -128,12 +123,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri return } - groupsMap := map[string]*types.Group{} - for _, group := range peerGroups { - groupsMap[group.ID] = group - } - - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -204,11 +194,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - groupsMap := map[string]*types.Group{} grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { @@ -217,7 +203,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(grps, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } diff --git a/management/server/peer.go b/management/server/peer.go index 78bb1dd8c..9921c0c9b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -101,7 +101,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra) if err != nil { return nil, err } @@ -335,6 +335,15 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) if err != nil { return err @@ -1057,12 +1066,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra) if err != nil { return nil, err } @@ -1139,6 +1148,11 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { + if !am.peersUpdateManager.HasChannel(peerId) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) + return + } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err) @@ -1151,11 +1165,6 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - if !am.peersUpdateManager.HasChannel(peerId) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) - return - } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) @@ -1185,7 +1194,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } if len(peersWithExpiry) == 0 { @@ -1195,7 +1204,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } var nextExpiry *time.Duration @@ -1229,7 +1238,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } if len(peersWithInactivity) == 0 { @@ -1239,7 +1248,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } var nextExpiry *time.Duration diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 82934dd11..46e896b55 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -332,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error) } return nil @@ -358,7 +358,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -381,7 +381,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -403,7 +403,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr Updates(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error) } if result.RowsAffected == 0 {