diff --git a/management/server/account.go b/management/server/account.go index acd024b04..426d94bf4 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -139,7 +139,7 @@ type AccountManager interface { GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -2294,17 +2294,7 @@ func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.Aut } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer accountUnlock() - peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) - defer peerUnlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) + peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, err } @@ -2331,18 +2321,11 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st return err } - unlock := am.Store.AcquireReadLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account) + _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID) if err != nil { return mapError(ctx, err) } + return nil } @@ -2436,7 +2419,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee } func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) + existingLabels, err := store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 0526204a5..08bd15e10 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -92,7 +92,7 @@ type MockAccountManager struct { GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -681,9 +681,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(ctx, sync, account) + return am.SyncPeerFunc(ctx, sync, accountID) } return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index b344ef042..f49c9609f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -66,12 +66,16 @@ func (am *DefaultAccountManager) ListPeers(ctx context.Context, accountID, userI // GetUserPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -85,11 +89,15 @@ func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, us regularUser := !user.HasAdminPower() && !user.IsServiceUser - if regularUser && account.Settings.RegularUsersViewBlocked { + if regularUser && settings.RegularUsersViewBlocked { return peers, nil } - for _, peer := range account.Peers { + accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + for _, peer := range accountPeers { if regularUser && user.Id != peer.UserID { // only display peers that belong to the current user if the current user is not an admin continue @@ -103,6 +111,11 @@ func (am *DefaultAccountManager) GetUserPeers(ctx context.Context, accountID, us return peers, nil } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf(errGetAccountFmt, err) + } + // fetch all the peers that have access to the user's peers for _, peer := range peers { aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) @@ -196,20 +209,31 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - peer := account.GetPeer(update.ID) - if peer == nil { - return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) } - update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID) + if err != nil { + return nil, err + } + + update, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) if err != nil { return nil, err } @@ -226,7 +250,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peer.Name != update.Name { peer.Name = update.Name - existingLabels := account.getPeerDNSLabels() + existingLabels, err := am.getPeerDNSLabels(ctx, accountID) + if err != nil { + return nil, err + } newLabel, err := getPeerHostLabel(peer.Name, existingLabels) if err != nil { @@ -252,13 +279,12 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { + if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } } if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { - if !peer.AddedWithSSOLogin() { return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") } @@ -271,90 +297,89 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) - if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { + if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { return nil, err } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, fmt.Errorf(errGetAccountFmt, err) + } am.updateAccountPeers(ctx, account) return peer, nil } // deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { +func (am *DefaultAccountManager) deletePeers(ctx context.Context, accountID string, userID string, peers []*nbpeer.Peer) error { + return am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, peer := range peers { + if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil { + return fmt.Errorf("failed to validate peer: %w", err) + } - // the first loop is needed to ensure all peers present under the account before modifying, otherwise - // we might have some inconsistencies - peers := make([]*nbpeer.Peer, 0, len(peerIDs)) - for _, peerID := range peerIDs { + network, err := transaction.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account network: %w", err) + } - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - peers = append(peers, peer) - } + if err = transaction.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil { + return fmt.Errorf("failed to delete peer: %w", err) + } - // the 2nd loop performs the actual modification - for _, peer := range peers { - - err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID) - if err != nil { - return err - } - - account.DeletePeer(peer.ID) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, - &UpdateMessage{ - Update: &proto.SyncResponse{ - // fill those field for backward compatibility - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - // new field - NetworkMap: &proto.NetworkMap{ - Serial: account.Network.CurrentSerial(), - RemotePeers: []*proto.RemotePeerConfig{}, - RemotePeersIsEmpty: true, - FirewallRules: []*proto.FirewallRule{}, - FirewallRulesIsEmpty: true, + am.peersUpdateManager.SendUpdate(ctx, peer.ID, + &UpdateMessage{ + Update: &proto.SyncResponse{ + // fill those field for backward compatibility + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + // new field + NetworkMap: &proto.NetworkMap{ + Serial: network.CurrentSerial(), + RemotePeers: []*proto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + FirewallRules: []*proto.FirewallRule{}, + FirewallRulesIsEmpty: true, + }, }, - }, - }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) - am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) - } + }) + am.peersUpdateManager.CloseChannel(ctx, peer.ID) + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) + } + return nil + }) - return nil } // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - err = am.deletePeers(ctx, account, []string{peerID}, userID) + if user.AccountID != accountID { + return status.Errorf(status.PermissionDenied, errUserNotPartOfAccountMsg) + } + + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) if err != nil { return err } - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.deletePeers(ctx, accountID, userID, []*nbpeer.Peer{peer}); err != nil { return err } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return fmt.Errorf(errGetAccountFmt, err) + } am.updateAccountPeers(ctx, account) return nil @@ -636,14 +661,14 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { - peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() } if peer.UserID != "" { - user, err := account.FindUser(peer.UserID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { return nil, nil, nil, err } @@ -654,23 +679,21 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } } - if peerLoginExpired(ctx, peer, account.Settings) { + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err + } + + if peerLoginExpired(ctx, peer, settings) { return nil, nil, nil, status.NewPeerLoginExpiredError() } - updated := peer.UpdateMetaIfNew(sync.Meta) - if updated { - err = am.Store.SavePeer(ctx, account.Id, peer) - if err != nil { - return nil, nil, nil, err - } - - if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account) - } + peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID) + if err != nil { + return nil, nil, nil, err } - peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra) if err != nil { return nil, nil, nil, err } @@ -678,22 +701,40 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac var postureChecks []*posture.Checks if peerNotValid { + network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err + } + emptyMap := &NetworkMap{ - Network: account.Network.Copy(), + Network: network.Copy(), } return peer, emptyMap, postureChecks, nil } - if isStatusChanged { - am.updateAccountPeers(ctx, account) + updated := peer.UpdateMetaIfNew(sync.Meta) + if updated { + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + if err != nil { + return nil, nil, nil, err + } } - validPeersMap, err := am.GetValidatedPeers(ctx, account.Id) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if isStatusChanged || (updated && sync.UpdateAccountPeers) { + am.updateAccountPeers(ctx, account) + } + + validPeersMap, err := am.GetValidatedPeers(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } + + postureChecks, err = am.getPeerPostureChecks(ctx, accountID, peer.ID) if err != nil { return nil, nil, nil, err } @@ -805,7 +846,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if shouldStorePeer { - err = am.Store.SavePeer(ctx, accountID, peer) + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) if err != nil { return nil, nil, nil, err } @@ -885,7 +926,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. peer = peer.UpdateLastLogin() - err = am.Store.SavePeer(ctx, peer.AccountID, peer) + err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) if err != nil { return err } @@ -1190,6 +1231,34 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p return peerGroups, nil } +// getPeerGroupIDs returns the IDs of the groups that the peer is part of. +func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) { + groups, err := am.GetPeerGroups(ctx, accountID, peerID) + if err != nil { + return nil, err + } + + groupIDs := make([]string, 0, len(groups)) + for _, group := range groups { + groupIDs = append(groupIDs, group.ID) + } + + return groupIDs, err +} + +func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) { + dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + existingLabels := make(lookupMap) + for _, label := range dnsLabels { + existingLabels[label] = struct{}{} + } + return existingLabels, nil +} + func ConvertSliceToMap(existingLabels []string) map[string]struct{} { labelMap := make(map[string]struct{}, len(existingLabels)) for _, label := range existingLabels { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index fe2e4ee76..5b2d61b59 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -291,12 +291,12 @@ func (s *SqlStore) GetInstallationID() string { return installation.InstallationIDValue } -func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { +func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { + err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -783,7 +783,8 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength return ips, nil } -func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { +// GetAccountPeerDNSLabels retrieves all unique DNS labels for peers associated with a specified account. +func (s *SqlStore) GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { var labels []string result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). @@ -804,10 +805,12 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { var accountNetwork AccountNetwork - if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + First(&accountNetwork, idQueryCondition, accountID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } + log.WithContext(ctx).Errorf("error when getting network from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) } return accountNetwork.Network, nil @@ -1139,6 +1142,21 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength return peer, nil } +func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, peerID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete peer from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "peer not found") + } + + return nil +} + func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 0a3e1aaf9..09be9ab85 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -403,7 +403,7 @@ func TestSqlite_SavePeer(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() - err = store.SavePeer(ctx, account.Id, peer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -419,7 +419,7 @@ func TestSqlite_SavePeer(t *testing.T) { updatedPeer.Status.Connected = false updatedPeer.Meta.Hostname = "updatedpeer" - err = store.SavePeer(ctx, account.Id, updatedPeer) + err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -1056,7 +1056,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + labels, err := store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID) require.NoError(t, err) assert.Equal(t, []string{}, labels) @@ -1068,7 +1068,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID) require.NoError(t, err) assert.Equal(t, []string{"peer1.domain.test"}, labels) @@ -1080,7 +1080,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + labels, err = store.GetAccountPeerDNSLabels(context.Background(), LockingStrengthShare, existingAccountID) require.NoError(t, err) assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) } diff --git a/management/server/store.go b/management/server/store.go index 1f2ebc81e..e4b948be6 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -91,7 +91,7 @@ type Store interface { SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error - GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + GetAccountPeerDNSLabels(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error @@ -101,9 +101,10 @@ type Store interface { GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) - SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error + SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error + DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error diff --git a/management/server/user.go b/management/server/user.go index 44b371b49..4c43c63fe 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -493,17 +493,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { - peers, err := account.FindUserPeers(targetUserID) + peers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, account.Id, targetUserID) if err != nil { - return status.Errorf(status.Internal, "failed to find user peers") + return err } - peerIDs := make([]string, 0, len(peers)) - for _, peer := range peers { - peerIDs = append(peerIDs, peer.ID) - } - - return am.deletePeers(ctx, account, peerIDs, initiatorUserID) + return am.deletePeers(ctx, account.Id, initiatorUserID, peers) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.