[management] Refactor peers to use store methods (#2893)

This commit is contained in:
Bethuel Mmbaga
2025-01-20 20:41:46 +03:00
committed by GitHub
parent c619bf5b0c
commit 1ad2cb5582
30 changed files with 1614 additions and 857 deletions

View File

@@ -45,6 +45,7 @@ import (
const (
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
)
@@ -85,7 +86,7 @@ type AccountManager interface {
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
@@ -105,6 +106,7 @@ type AccountManager interface {
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error)
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
@@ -126,8 +128,8 @@ type AccountManager interface {
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
@@ -138,7 +140,7 @@ type AccountManager interface {
GetIdpManager() idp.Manager
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *types.Account) (map[string]struct{}, error)
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
@@ -379,14 +381,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
updateAccountPeers := false
@@ -400,7 +402,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
account.Network.Serial++
}
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
@@ -437,13 +439,13 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
return nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error {
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
} else {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
@@ -452,7 +454,7 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
@@ -466,33 +468,31 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
expiredPeers, err := am.getExpiredPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextPeerExpiration()
return peerSchedulerRetryInterval, true
}
expiredPeers := account.GetExpiredPeers()
var peerIDs []string
for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID)
}
log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextPeerExpiration()
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID)
return peerSchedulerRetryInterval, true
}
return account.GetNextPeerExpiration()
return am.getNextPeerExpiration(ctx, accountID)
}
}
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *types.Account) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextPeerExpiration(); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
}
}
@@ -502,34 +502,33 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
inactivePeers, err := am.getInactivePeers(ctx, accountID)
if err != nil {
log.Errorf("failed getting account %s expiring peers", accountID)
return account.GetNextInactivePeerExpiration()
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
return peerSchedulerRetryInterval, true
}
expiredPeers := account.GetInactivePeers()
var peerIDs []string
for _, peer := range expiredPeers {
for _, peer := range inactivePeers {
peerIDs = append(peerIDs, peer.ID)
}
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
return account.GetNextInactivePeerExpiration()
if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return peerSchedulerRetryInterval, true
}
return account.GetNextInactivePeerExpiration()
return am.getNextInactivePeerExpiration(ctx, accountID)
}
}
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *types.Account) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
}
}
@@ -665,7 +664,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided")
}
accountID, err := am.Store.GetAccountIDByUserID(userID)
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -1450,7 +1449,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err
}
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@@ -1497,7 +1496,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
}
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId)
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId)
if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err
@@ -1559,17 +1558,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, nil, nil, status.NewGetAccountError(err)
}
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account)
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1583,12 +1577,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return status.NewGetAccountError(err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
@@ -1609,12 +1598,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer unlockPeer()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
_, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
_, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return mapError(ctx, err)
}
@@ -1683,8 +1667,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
}
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID)
if err != nil {
return false, err
}
@@ -1695,7 +1679,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
}
if peerLoginExpired(ctx, peer, settings) {
err = am.handleExpiredPeer(ctx, user, peer)
err = am.handleExpiredPeer(ctx, transaction, user, peer)
if err != nil {
return false, err
}