From 9bf0bf484327d454c1dafef42ab277d4d270ebef Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 28 Oct 2024 17:47:54 +0300 Subject: [PATCH] wip: refactor get account in peers Signed-off-by: bcmmbaga --- management/server/account.go | 7 +- management/server/http/groups_handler.go | 8 +- management/server/http/peers_handler.go | 97 +++++++------- management/server/integrated_validator.go | 31 ++++- management/server/mock_server/account_mock.go | 46 ++++--- management/server/peer.go | 122 +++++++++--------- management/server/peer_test.go | 8 +- 7 files changed, 176 insertions(+), 143 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c08665c2a..acd024b04 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -94,7 +94,8 @@ type AccountManager interface { GetUserByID(ctx context.Context, id string) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error) - GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -105,7 +106,6 @@ type AccountManager interface { DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) @@ -116,6 +116,7 @@ type AccountManager interface { DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error @@ -149,7 +150,7 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *Account) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index f369d1a00..75aad653f 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -49,7 +49,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -132,7 +132,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -180,7 +180,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -238,7 +238,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + accountPeers, err := h.accountManager.GetUserPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e4..608207927 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -48,8 +48,8 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) +func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { + peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) if err != nil { util.WriteError(ctx, err, w) return @@ -62,11 +62,16 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := toGroupsInfo(account.Groups, peer.ID) - - validPeers, err := h.accountManager.GetValidatedPeers(account) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + util.WriteError(ctx, err, w) + return + } + groupsInfo := toGroupsInfo(peerGroups) + + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } @@ -75,7 +80,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -99,16 +104,21 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) + peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) if err != nil { util.WriteError(ctx, err, w) return } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + if err != nil { + util.WriteError(ctx, err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -149,18 +159,11 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { case http.MethodDelete: h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodGet, http.MethodPut: - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - if r.Method == http.MethodGet { - h.getPeer(r.Context(), account, peerID, userID, w) - } else { - h.updatePeer(r.Context(), account, userID, peerID, w, r) - } + case http.MethodGet: + h.getPeer(r.Context(), accountID, peerID, userID, w) + return + case http.MethodPut: + h.updatePeer(r.Context(), accountID, userID, peerID, w, r) return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -176,7 +179,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + peers, err := h.accountManager.ListPeers(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -184,19 +187,25 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + + peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + groupMinimumInfo := toGroupsInfo(peerGroups) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(account) + validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -259,16 +268,16 @@ func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request } } - dnsDomain := h.accountManager.GetDNSDomain() - - validPeers, err := h.accountManager.GetValidatedPeers(account) + validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) + dnsDomain := h.accountManager.GetDNSDomain() + + customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) @@ -303,26 +312,14 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum - groupsChecked := make(map[string]struct{}) +func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { + groupsInfo := make([]api.GroupMinimum, 0, len(groups)) for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } + groupsInfo = append(groupsInfo, api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + }) } return groupsInfo } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index ba6a20259..dc3c47596 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -4,6 +4,8 @@ import ( "context" "errors" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" @@ -78,6 +80,31 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { - return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + groupsMap := make(map[string]*nbgroup.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 18f6ff16c..0526204a5 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -31,7 +31,8 @@ type MockAccountManager struct { GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) - GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + GetUserPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) + ListPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error @@ -47,6 +48,7 @@ type MockAccountManager struct { DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error @@ -56,7 +58,6 @@ type MockAccountManager struct { GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerSSHKeyFunc func(ctx context.Context, peerID string, sshKey string) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) @@ -123,7 +124,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { + account, err := am.GetAccountFunc(ctx, accountID) + if err != nil { + return nil, err + } + approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -425,14 +431,6 @@ func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ( return nil, status.Errorf(codes.Unimplemented, "method ListUsers is not implemented") } -// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager -func (am *MockAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if am.UpdatePeerSSHKeyFunc != nil { - return am.UpdatePeerSSHKeyFunc(ctx, peerID, sshKey) - } - return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is not implemented") -} - // UpdatePeer mocks UpdatePeerFunc function of the account manager func (am *MockAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) { if am.UpdatePeerFunc != nil { @@ -618,12 +616,12 @@ func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, cl return status.Errorf(codes.Unimplemented, "method CheckUserAccessByJWTGroups is not implemented") } -// GetPeers mocks GetPeers of the AccountManager interface -func (am *MockAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { - if am.GetPeersFunc != nil { - return am.GetPeersFunc(ctx, accountID, userID) +// GetUserPeers mocks GetUserPeers of the AccountManager interface +func (am *MockAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + if am.GetUserPeersFunc != nil { + return am.GetUserPeersFunc(ctx, accountID, userID) } - return nil, status.Errorf(codes.Unimplemented, "method GetPeers is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method GetUserPeers is not implemented") } // GetDNSDomain mocks GetDNSDomain of the AccountManager interface @@ -832,3 +830,19 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") } + +// GetPeerGroups mocks GetPeerGroups of the AccountManager interface +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) { + if am.GetPeerGroupsFunc != nil { + return am.GetPeerGroupsFunc(ctx, accountID, peerID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented") +} + +// ListPeers mocks ListPeers of the AccountManager interface +func (am *MockAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + if am.ListPeersFunc != nil { + return am.ListPeersFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method ListPeers is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index eb60d02cb..b344ef042 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -4,10 +4,12 @@ import ( "context" "fmt" "net" + "slices" "strings" "sync" "time" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -47,9 +49,23 @@ type PeerLogin struct { ConnectionIP net.IP } -// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if +// ListPeers returns a list of peers under the given account. +func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + + return am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) +} + +// GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. -func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { +func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -60,7 +76,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) if err != nil { return nil, err } @@ -585,7 +601,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s am.updateAccountPeers(ctx, account) - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) if err != nil { return nil, nil, nil, err } @@ -672,7 +688,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac am.updateAccountPeers(ctx, account) } - validPeersMap, err := am.GetValidatedPeers(account) + validPeersMap, err := am.GetValidatedPeers(ctx, account.Id) if err != nil { return nil, nil, nil, err } @@ -847,7 +863,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, emptyMap, nil, nil } - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) if err != nil { return nil, nil, nil, err } @@ -914,92 +930,53 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings return false } -// UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if sshKey == "" { - log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) - return nil - } - - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, account.Id) - if err != nil { - return err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - if peer.SSHKey == sshKey { - log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) - return nil - } - - peer.SSHKey = sshKey - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err - } - - // trigger network map update - am.updateAccountPeers(ctx, account) - - return nil -} - // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } - if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { + if !user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked { return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) } - peer := account.GetPeer(peerID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return nil, err } // if admin or user owns this peer, return peer - if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { + if user.IsAdminOrServiceUser() || peer.UserID == userID { return peer, nil } // it is also possible that user doesn't own the peer but some of his peers have access to it, // this is a valid case, show the peer as well. - userPeers, err := account.FindUserPeers(userID) + userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID) if err != nil { return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) if err != nil { return nil, err } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf(errGetAccountFmt, err) + } + for _, p := range userPeers { aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { @@ -1024,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account peers := account.GetPeers() - approvedPeersMap, err := am.GetValidatedPeers(account) + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) return @@ -1196,6 +1173,23 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID return peers, nil } +// GetPeerGroups returns groups that the peer is part of. +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerGroups := make([]*nbgroup.Group, 0) + for _, group := range groups { + if slices.Contains(group.Peers, peerID) { + peerGroups = append(peerGroups, group) + } + } + + return peerGroups, nil +} + func ConvertSliceToMap(existingLabels []string) map[string]struct{} { labelMap := make(map[string]struct{}, len(existingLabels)) for _, label := range existingLabels { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c5edb5636..0f1bb1888 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -561,7 +561,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { assert.NotNil(t, peer) } -func TestDefaultAccountManager_GetPeers(t *testing.T) { +func TestDefaultAccountManager_GetUserPeers(t *testing.T) { testCases := []struct { name string role UserRole @@ -697,7 +697,7 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { return } - peers, err := manager.GetPeers(context.Background(), accountID, someUser) + peers, err := manager.GetUserPeers(context.Background(), accountID, someUser) if err != nil { t.Fatal(err) return @@ -822,9 +822,9 @@ func BenchmarkGetPeers(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := manager.GetPeers(context.Background(), accountID, userID) + _, err := manager.GetUserPeers(context.Background(), accountID, userID) if err != nil { - b.Fatalf("GetPeers failed: %v", err) + b.Fatalf("GetUserPeers failed: %v", err) } } })