diff --git a/management/server/account.go b/management/server/account.go index d6c14afc3..9bb029b51 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,7 +76,7 @@ type AccountManager interface { GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(accountID string) ([]*User, error) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error + MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error DeletePeer(accountID, peerID, userID string) error UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(peerID string) (*NetworkMap, error) @@ -117,8 +117,8 @@ type AccountManager interface { SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API - SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API + LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API + SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -130,6 +130,8 @@ type AccountManager interface { UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error GroupValidation(accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) + SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) + CancelPeerRoutines(peer *nbpeer.Peer) error } type DefaultAccountManager struct { @@ -1864,6 +1866,62 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } +func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) { + startTime := time.Now() + defer func() { + duration := time.Since(startTime) + log.Debugf("SyncAndMarkPeer took %s", duration) + }() + + accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey) + if err != nil { + return nil, nil, err + } + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, nil, err + } + + peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account) + if err != nil { + return nil, nil, mapError(err) + } + + err = am.MarkPeerConnected(peerPubKey, true, realIP, account) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peerPubKey, err) + } + + return peer, netMap, nil +} + +func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { + accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key) + if err != nil { + return err + } + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + err = am.MarkPeerConnected(peer.Key, false, nil, account) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peer.Key, err) + } + + return nil + +} + // GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { return am.peersUpdateManager.GetAllConnectedPeers(), nil diff --git a/management/server/file_store.go b/management/server/file_store.go index 2de852bee..ebc96f4be 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,6 +1,7 @@ package server import ( + "errors" "os" "path/filepath" "strings" @@ -572,6 +573,10 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return account.Copy(), nil } +func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { + return "", errors.New("not implemented") +} + // GetInstallationID returns the installation ID from the store func (s *FileStore) GetInstallationID() string { return s.InstallationID diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index ddf88ef6d..df95e5a27 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -140,9 +140,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi return err } - peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()}) + peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP) if err != nil { - return mapError(err) + return err } err = s.sendInitialSync(peerKey, peer, netMap, srv) @@ -155,11 +155,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.ephemeralManager.OnPeerConnected(peer) - err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP) - if err != nil { - log.Warnf("failed marking peer as connected %s %v", peerKey, err) - } - if s.config.TURNConfig.TimeBasedCredentials { s.turnCredentialsManager.SetupRefresh(peer.ID) } @@ -213,7 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.MarkPeerConnected(peer.Key, false, nil) + _ = s.accountManager.CancelPeerRoutines(peer) s.ephemeralManager.OnPeerDisconnected(peer) } diff --git a/management/server/peer.go b/management/server/peer.go index 784cd7bf2..5140620ea 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -88,27 +88,13 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error { +func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error { startTime := time.Now() defer func() { duration := time.Since(startTime) log.Debugf("MarkPeerConnected took %s", duration) }() - account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) - if err != nil { - return err - } - - unlock := am.Store.AcquireAccountLock(account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(account.Id) - if err != nil { - return err - } - peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err @@ -524,31 +510,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) { +func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) { startTime := time.Now() defer func() { duration := time.Since(startTime) log.Debugf("SyncPeer took %s", duration) }() - account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") - } - return nil, nil, err - } - - // we found the peer, and we follow a normal login flow - unlock := am.Store.AcquireAccountLock(account.Id) - defer unlock() - - // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err = am.Store.GetAccount(account.Id) - if err != nil { - return nil, nil, err - } - peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index c60e96e6b..93dbec473 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -280,20 +280,9 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer duration := time.Since(startTime) log.Debugf("SavePeerStatus took %s", duration) }() - var peer nbpeer.Peer - result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) - if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - log.Errorf("error when getting peer from the store: %s", result.Error) - return status.Errorf(status.Internal, "issue getting peer from store") - } - - peer.Status = &peerStatus - - return s.db.Save(peer).Error + s.db.Where("account_id = ? and id = ?", accountID, peerID).Update("status", peerStatus) + return nil } func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { @@ -303,7 +292,6 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee log.Debugf("SavePeerLocation took %s", duration) }() - log.Info("saving peer location") s.db.Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).Update("location", peerWithLocation.Location) return nil } @@ -563,6 +551,26 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return s.GetAccount(peer.AccountID) } +func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { + startTime := time.Now() + defer func() { + duration := time.Since(startTime) + log.Debugf("GetAccountByPubKey took %s", duration) + }() + + var accountID string + result := s.db.Select("account_id").Where("key = ?", peerKey).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return "", status.Errorf(status.Internal, "issue getting account from store") + } + + return accountID, nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { startTime := time.Now() diff --git a/management/server/store.go b/management/server/store.go index 77b8d0dad..26a1b8a7e 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -19,6 +19,7 @@ type Store interface { DeleteAccount(account *Account) error GetAccountByUser(userID string) (*Account, error) GetAccountByPeerPubKey(peerKey string) (*Account, error) + GetAccountIDByPeerPubKey(peerKey string) (string, error) GetAccountByPeerID(peerID string) (*Account, error) GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(domain string) (*Account, error)