run peer ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-11-18 15:06:25 +03:00
parent f6f7260897
commit a61e9da3e9
6 changed files with 233 additions and 163 deletions

View File

@@ -11,6 +11,7 @@ import (
"sync"
"time"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@@ -117,17 +118,25 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
if err != nil {
return err
}
var peer *nbpeer.Peer
var settings *Settings
var expired bool
var err error
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, peerPubKey)
if err != nil {
return err
}
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
return err
})
if err != nil {
return err
}
@@ -151,7 +160,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC()
@@ -162,8 +171,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
}
peer.Status = newStatus
if am.geo != nil && realIP != nil {
location, err := am.geo.Lookup(realIP)
if geo != nil && realIP != nil {
location, err := geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
} else {
@@ -171,14 +180,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
err = transaction.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
}
err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
err := transaction.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
if err != nil {
return false, err
}
@@ -200,23 +209,49 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, status.NewUserNotPartOfAccountError()
}
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
if err != nil {
return nil, err
}
var peer *nbpeer.Peer
var settings *Settings
var peerGroupList []string
var requiresPeerUpdates bool
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
var newLabel string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID)
if err != nil {
return err
}
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
peerGroupList, err = getPeerGroupIDs(ctx, am.Store, accountID, update.ID)
if err != nil {
return err
}
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
if err != nil {
return err
}
if peer.Name != update.Name {
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
if err != nil {
return err
}
newLabel, err = getPeerHostLabel(update.Name, existingLabels)
if err != nil {
return err
}
peer.DNSLabel = newLabel
}
return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
})
if err != nil {
return nil, err
}
@@ -231,18 +266,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if peer.Name != update.Name {
peer.Name = update.Name
peerLabelChanged = true
existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
if err != nil {
return nil, err
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
peer.DNSLabel = newLabel
}
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
@@ -261,10 +284,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
inactivityExpirationChanged = true
}
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return nil, err
}
if sshChanged {
event := activity.PeerSSHEnabled
if !peer.SSHEnabled {
@@ -313,13 +332,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
if err != nil {
return err
}
if peerAccountID != accountID {
return status.NewPeerNotPartOfAccountError()
}
var peer *nbpeer.Peer
var addPeerRemovedEvents []func()
var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID)
@@ -327,16 +351,21 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
if err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
return err
})
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent()
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
@@ -433,6 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
var newPeer *nbpeer.Peer
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var setupKeyID string
@@ -480,7 +510,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to get free DNS label: %w", err)
}
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
freeIP, err := getFreeIP(ctx, transaction, accountID)
if err != nil {
return fmt.Errorf("failed to get free IP: %w", err)
}
@@ -564,6 +594,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
}
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil
})
@@ -581,11 +616,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
unlock()
unlock = nil
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
@@ -593,13 +623,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
func getFreeIP(ctx context.Context, transaction Store, accountID string) (net.IP, error) {
takenIps, err := transaction.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
}
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
network, err := transaction.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err)
}
@@ -614,48 +644,59 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError()
}
var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
var updated bool
var err error
if peer.UserID != "" {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
return status.NewPeerNotRegisteredError()
}
err = checkIfPeerOwnerIsBlocked(peer, user)
if err != nil {
return nil, nil, nil, err
if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil {
return err
}
if err = checkIfPeerOwnerIsBlocked(peer, user); err != nil {
return err
}
}
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
if peerLoginExpired(ctx, peer, settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
if err != nil {
return nil, nil, nil, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
return err
}
if peerLoginExpired(ctx, peer, settings) {
return status.NewPeerLoginExpiredError()
}
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return err
}
updated = peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, nil, nil, err
}
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
@@ -707,73 +748,73 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
}()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
}
var peer *nbpeer.Peer
var updateRemotePeers bool
var isRequiresApproval bool
var isStatusChanged bool
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
updateRemotePeers := false
if login.UserID != "" {
if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
}
changed, err := am.handleUserPeer(ctx, peer, settings)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
return nil, nil, nil, err
return err
}
if changed {
shouldStorePeer = true
updateRemotePeers = true
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
}
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
if login.UserID != "" {
if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
return status.Errorf(status.Unauthenticated, "invalid user")
}
var grps []string
for _, group := range groups {
for _, id := range group.Peers {
if id == peer.ID {
grps = append(grps, group.ID)
break
changed, err := am.handleUserPeer(ctx, transaction, peer, settings)
if err != nil {
return err
}
if changed {
shouldStorePeer = true
updateRemotePeers = true
}
}
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return err
}
isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return err
}
updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
shouldStorePeer = true
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
}
if shouldStorePeer {
if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return err
}
}
return nil
})
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
shouldStorePeer = true
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
}
if shouldStorePeer {
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, nil, nil, err
}
}
unlockPeer()
unlockPeer = nil
@@ -845,7 +886,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction Store, user *User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer)
if err != nil {
return err
@@ -853,12 +894,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
// If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
err = transaction.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
if err != nil {
return err
}
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil {
return err
}
@@ -1149,7 +1190,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
return getPeerGroups(ctx, am.Store, accountID, peerID)
}
// getPeerGroups returns the IDs of the groups that the peer is part of.
func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
@@ -1165,8 +1211,8 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) {
groups, err := getPeerGroups(ctx, transaction, accountID, peerID)
if err != nil {
return nil, err
}
@@ -1179,8 +1225,8 @@ func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID
return groupIDs, err
}
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) {
dnsLabels, err := transaction.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
@@ -1194,12 +1240,12 @@ func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) {
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
if err != nil {
return false, err
}
return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
}
// deletePeers deletes all specified peers and sends updates to the remote peers.