mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-08 09:49:54 +00:00
try improving Sync method
This commit is contained in:
@@ -76,7 +76,7 @@ type AccountManager interface {
|
||||
GetUser(claims jwtclaims.AuthorizationClaims) (*User, error)
|
||||
ListUsers(accountID string) ([]*User, error)
|
||||
GetPeers(accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error
|
||||
MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error
|
||||
DeletePeer(accountID, peerID, userID string) error
|
||||
UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
|
||||
GetNetworkMap(peerID string) (*NetworkMap, error)
|
||||
@@ -117,8 +117,8 @@ type AccountManager interface {
|
||||
SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||
GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error)
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
HasConnectedChannel(peerID string) bool
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
@@ -130,6 +130,8 @@ type AccountManager interface {
|
||||
UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error
|
||||
GroupValidation(accountId string, groups []string) (bool, error)
|
||||
GetValidatedPeers(account *Account) (map[string]struct{}, error)
|
||||
SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error)
|
||||
CancelPeerRoutines(peer *nbpeer.Peer) error
|
||||
}
|
||||
|
||||
type DefaultAccountManager struct {
|
||||
@@ -1864,6 +1866,62 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(startTime)
|
||||
log.Debugf("SyncAndMarkPeer took %s", duration)
|
||||
}()
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account)
|
||||
if err != nil {
|
||||
return nil, nil, mapError(err)
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peerPubKey, true, realIP, account)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return peer, netMap, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error {
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.MarkPeerConnected(peer.Key, false, nil, account)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peer.Key, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers()
|
||||
func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) {
|
||||
return am.peersUpdateManager.GetAllConnectedPeers(), nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -572,6 +573,10 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return account.Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
return "", errors.New("not implemented")
|
||||
}
|
||||
|
||||
// GetInstallationID returns the installation ID from the store
|
||||
func (s *FileStore) GetInstallationID() string {
|
||||
return s.InstallationID
|
||||
|
||||
@@ -140,9 +140,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
return err
|
||||
}
|
||||
|
||||
peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()})
|
||||
peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP)
|
||||
if err != nil {
|
||||
return mapError(err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = s.sendInitialSync(peerKey, peer, netMap, srv)
|
||||
@@ -155,11 +155,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(peer)
|
||||
|
||||
err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP)
|
||||
if err != nil {
|
||||
log.Warnf("failed marking peer as connected %s %v", peerKey, err)
|
||||
}
|
||||
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
s.turnCredentialsManager.SetupRefresh(peer.ID)
|
||||
}
|
||||
@@ -213,7 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) {
|
||||
s.peersUpdateManager.CloseChannel(peer.ID)
|
||||
s.turnCredentialsManager.CancelRefresh(peer.ID)
|
||||
_ = s.accountManager.MarkPeerConnected(peer.Key, false, nil)
|
||||
_ = s.accountManager.CancelPeerRoutines(peer)
|
||||
s.ephemeralManager.OnPeerDisconnected(peer)
|
||||
}
|
||||
|
||||
|
||||
@@ -88,27 +88,13 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P
|
||||
}
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error {
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error {
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(startTime)
|
||||
log.Debugf("MarkPeerConnected took %s", duration)
|
||||
}()
|
||||
|
||||
account, err := am.Store.GetAccountByPeerPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account)
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -524,31 +510,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P
|
||||
}
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) {
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(startTime)
|
||||
log.Debugf("SyncPeer took %s", duration)
|
||||
}()
|
||||
|
||||
account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// we found the peer, and we follow a normal login flow
|
||||
unlock := am.Store.AcquireAccountLock(account.Id)
|
||||
defer unlock()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err = am.Store.GetAccount(account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered")
|
||||
|
||||
@@ -280,20 +280,9 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer
|
||||
duration := time.Since(startTime)
|
||||
log.Debugf("SavePeerStatus took %s", duration)
|
||||
}()
|
||||
var peer nbpeer.Peer
|
||||
|
||||
result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.Errorf(status.NotFound, "peer %s not found", peerID)
|
||||
}
|
||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "issue getting peer from store")
|
||||
}
|
||||
|
||||
peer.Status = &peerStatus
|
||||
|
||||
return s.db.Save(peer).Error
|
||||
s.db.Where("account_id = ? and id = ?", accountID, peerID).Update("status", peerStatus)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
|
||||
@@ -303,7 +292,6 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee
|
||||
log.Debugf("SavePeerLocation took %s", duration)
|
||||
}()
|
||||
|
||||
log.Info("saving peer location")
|
||||
s.db.Where("account_id = ? and id = ?", accountID, peerWithLocation.ID).Update("location", peerWithLocation.Location)
|
||||
return nil
|
||||
}
|
||||
@@ -563,6 +551,26 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) {
|
||||
return s.GetAccount(peer.AccountID)
|
||||
}
|
||||
|
||||
func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) {
|
||||
startTime := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(startTime)
|
||||
log.Debugf("GetAccountByPubKey took %s", duration)
|
||||
}()
|
||||
|
||||
var accountID string
|
||||
result := s.db.Select("account_id").Where("key = ?", peerKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
}
|
||||
log.Errorf("error when getting peer from the store: %s", result.Error)
|
||||
return "", status.Errorf(status.Internal, "issue getting account from store")
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error {
|
||||
startTime := time.Now()
|
||||
|
||||
@@ -19,6 +19,7 @@ type Store interface {
|
||||
DeleteAccount(account *Account) error
|
||||
GetAccountByUser(userID string) (*Account, error)
|
||||
GetAccountByPeerPubKey(peerKey string) (*Account, error)
|
||||
GetAccountIDByPeerPubKey(peerKey string) (string, error)
|
||||
GetAccountByPeerID(peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(domain string) (*Account, error)
|
||||
|
||||
Reference in New Issue
Block a user