From 049b5fb7ede553da0d812590d083b6c77e5ca4a2 Mon Sep 17 00:00:00 2001 From: pascal-fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 19 Aug 2024 12:50:11 +0200 Subject: [PATCH] Split DB calls in peer login (#2439) --- management/server/account.go | 22 ++++++ management/server/file_store.go | 29 ++++++++ management/server/peer.go | 116 +++++++++++++++++--------------- management/server/sql_store.go | 28 ++++++++ management/server/store.go | 2 + 5 files changed, 144 insertions(+), 53 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 972272746..4c150fd7e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2072,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { + user, err := am.Store.GetUserByUserID(ctx, peer.UserID) + if err != nil { + return false, err + } + + err = checkIfPeerOwnerIsBlocked(peer, user) + if err != nil { + return false, err + } + + if peerLoginExpired(ctx, peer, settings) { + err = am.handleExpiredPeer(ctx, user, peer) + if err != nil { + return false, err + } + return true, nil + } + + return false, nil +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { diff --git a/management/server/file_store.go b/management/server/file_store.go index 6e3536bcd..1927568ef 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, return account.Users[userID].Copy(), nil } +func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) { + accountID, ok := s.UserID2AccountID[userID] + if !ok { + return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Users[userID].Copy(), nil +} + +func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups)) + + for _, group := range account.Groups { + groupsSlice = append(groupsSlice, group) + } + + return groupsSlice, nil +} + // GetAllAccounts returns all accounts func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { s.mux.Lock() diff --git a/management/server/peer.go b/management/server/peer.go index 7afe6ee0d..93234d9de 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -549,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, status.NewPeerNotRegisteredError() } - err = checkIfPeerOwnerIsBlocked(peer, account) - if err != nil { - return nil, nil, nil, err + if peer.UserID != "" { + log.Infof("Peer has no userID") + + user, err := account.FindUser(peer.UserID) + if err != nil { + return nil, nil, nil, err + } + + err = checkIfPeerOwnerIsBlocked(peer, user) + if err != nil { + return nil, nil, nil, err + } } if peerLoginExpired(ctx, peer, account.Settings) { return nil, nil, nil, status.NewPeerLoginExpiredError() } - peer, updated := updatePeerMeta(peer, sync.Meta, account) + updated := peer.UpdateMetaIfNew(sync.Meta) if updated { err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { @@ -624,31 +633,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // it means that the client has already checked if it needs login and had been through the SSO flow // so, we can skip this check and directly proceed with the login if login.UserID == "" { + log.Info("Peer needs login") err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) if err != nil { return nil, nil, nil, err } } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID) + defer unlockAccount() + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey) defer func() { - if unlock != nil { - unlock() + if unlockPeer != nil { + unlockPeer() } }() - // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err := am.Store.GetAccount(ctx, accountID) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - peer, err := account.FindPeerByPubKey(login.WireGuardPubKey) - if err != nil { - return nil, nil, nil, status.NewPeerNotRegisteredError() - } - - err = checkIfPeerOwnerIsBlocked(peer, account) + settings, err := am.Store.GetAccountSettings(ctx, accountID) if err != nil { return nil, nil, nil, err } @@ -656,21 +662,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // this flag prevents unnecessary calls to the persistent store. shouldStorePeer := false updateRemotePeers := false - if peerLoginExpired(ctx, peer, account.Settings) { - err = am.handleExpiredPeer(ctx, login, account, peer) + + if login.UserID != "" { + changed, err := am.handleUserPeer(ctx, peer, settings) if err != nil { return nil, nil, nil, err } - updateRemotePeers = true - shouldStorePeer = true + if changed { + shouldStorePeer = true + updateRemotePeers = true + } } - isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) + groups, err := am.Store.GetAccountGroups(ctx, accountID) if err != nil { return nil, nil, nil, err } - peer, updated := updatePeerMeta(peer, login.Meta, account) + var grps []string + for _, group := range groups { + for _, id := range group.Peers { + if id == peer.ID { + grps = append(grps, group.ID) + break + } + } + } + + isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra) + if err != nil { + return nil, nil, nil, err + } + + updated := peer.UpdateMetaIfNew(login.Meta) if updated { shouldStorePeer = true } @@ -687,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - unlock() - unlock = nil + unlockPeer() + unlockPeer = nil + + account, err := am.Store.GetAccount(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } if updateRemotePeers || isStatusChanged { am.updateAccountPeers(ctx, account) @@ -746,36 +775,30 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error { - err := checkAuth(ctx, login.UserID, peer) +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { + err := checkAuth(ctx, user.Id, peer) if err != nil { return err } // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. - updatePeerLastLogin(peer, account) - - // sync user last login with peer last login - user, err := account.FindUser(login.UserID) - if err != nil { - return status.Errorf(status.Internal, "couldn't find user") - } - - err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin) + peer = peer.UpdateLastLogin() + err = am.Store.SavePeer(ctx, peer.AccountID, peer) if err != nil { return err } - am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) + err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin) + if err != nil { + return err + } + + am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain())) return nil } -func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error { +func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error { if peer.AddedWithSSOLogin() { - user, err := account.FindUser(peer.UserID) - if err != nil { - return status.Errorf(status.PermissionDenied, "user doesn't exist") - } if user.IsBlocked() { return status.Errorf(status.PermissionDenied, "user is blocked") } @@ -805,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings return false } -func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) { - peer.UpdateLastLogin() - account.UpdatePeer(peer) -} - // UpdatePeerSSHKey updates peer's public SSH key func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { if sshKey == "" { @@ -908,14 +926,6 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) } -func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) { - if peer.UpdateMetaIfNew(meta) { - account.UpdatePeer(peer) - return peer, true - } - return peer, false -} - // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index c44ab7f09..912e31410 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } +func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) { + var user User + result := s.db.First(&user, idQueryCondition, userID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "user not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting user from store") + } + + return &user, nil +} + +func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { + var groups []*nbgroup.Group + result := s.db.Find(&groups, idQueryCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") + } + log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting groups from store") + } + + return groups, nil +} + func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { var accounts []Account result := s.db.Find(&accounts) diff --git a/management/server/store.go b/management/server/store.go index 864871c8e..a2b489391 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -41,6 +41,8 @@ type Store interface { GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) + GetUserByUserID(ctx context.Context, userID string) (*User, error) + GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) SaveAccount(ctx context.Context, account *Account) error SaveUsers(accountID string, users map[string]*User) error