[management] Rework DB locks (#4291)

This commit is contained in:
Pascal Fischer
2025-08-06 18:55:14 +02:00
committed by GitHub
parent dfd8bbc015
commit 5860e5343f
32 changed files with 667 additions and 673 deletions

View File

@@ -40,12 +40,12 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -346,12 +346,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
} }
if updateAccountPeers || groupsUpdated { if updateAccountPeers || groupsUpdated {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
} }
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings) return transaction.SaveAccountSettings(ctx, accountID, newSettings)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -405,7 +405,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists. // AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountIDByUserID retrieves the account ID based on the userID provided. // GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -758,7 +758,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided") return "", status.Errorf(status.NotFound, "no valid userID provided")
} }
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -813,7 +813,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -867,7 +867,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -897,7 +897,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err return nil, err
@@ -1048,7 +1048,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount() defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err return err
@@ -1058,7 +1058,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil return nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err) log.WithContext(ctx).Errorf("error getting user: %v", err)
return err return err
@@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
newUser := types.NewRegularUser(userAuth.UserId) newUser := types.NewRegularUser(userAuth.UserId)
newUser.AccountID = domainAccountID newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) err := am.Store.SaveUser(ctx, newUser)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -1223,7 +1223,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountOnboarding retrieves the onboarding information for a specific account. // GetAccountOnboarding retrieves the onboarding information for a specific account.
@@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err return "", "", err
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return err return err
} }
@@ -1366,12 +1366,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool var hasChanges bool
var user *types.User var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1387,7 +1387,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
@@ -1395,13 +1395,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups)
user.AutoGroups = updatedAutoGroups user.AutoGroups = updatedAutoGroups
if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { if err = transaction.SaveUser(ctx, user); err != nil {
return fmt.Errorf("error saving user: %w", err) return fmt.Errorf("error saving user: %w", err)
} }
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user peers: %w", err) return fmt.Errorf("error getting user peers: %w", err)
} }
@@ -1419,7 +1419,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@@ -1437,7 +1437,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1450,7 +1450,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1511,7 +1511,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
} }
if userAuth.IsChild { if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists { if err != nil || !exists {
return "", err return "", err
} }
@@ -1535,7 +1535,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err return "", err
} }
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1556,7 +1556,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
} }
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1571,7 +1571,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests // check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
cancel() cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1582,7 +1582,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
} }
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1592,7 +1592,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
} }
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err return "", err
@@ -1603,7 +1603,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
} }
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err return "", err
@@ -1751,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -1780,7 +1780,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1870,7 +1870,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel() defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
return nil, false, err return nil, false, err
} }
@@ -1890,7 +1890,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 { for range 2 {
accountId := xid.New().String() accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists { if err != nil || exists {
continue continue
} }
@@ -1965,7 +1965,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil return nil
} }
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error // error is not a not found error
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
@@ -2002,17 +2002,17 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error. // Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID) accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, fmt.Errorf("error getting account group peers: %w", err) return false, false, fmt.Errorf("error getting account group peers: %w", err)
} }
accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, fmt.Errorf("error getting account groups: %w", err) return false, false, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -2025,7 +2025,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
updatedGroups := []string{} updatedGroups := []string{}
for _, user := range users { for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -2074,7 +2074,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
account.Network.Net = newIPNet account.Network.Net = newIPNet
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -2099,7 +2099,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t
} }
for _, peer := range peers { for _, peer := range peers {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err)
} }
} }
@@ -2154,7 +2154,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
return fmt.Errorf("get account: %w", err) return fmt.Errorf("get account: %w", err)
} }
existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return fmt.Errorf("get peer: %w", err) return fmt.Errorf("get peer: %w", err)
} }
@@ -2185,7 +2185,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context,
func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error { func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error {
log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP) log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP)
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("get account settings: %w", err) return fmt.Errorf("get account settings: %w", err)
} }
@@ -2195,7 +2195,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti
oldIP := peer.IP.String() oldIP := peer.IP.String()
peer.IP = newIP.AsSlice() peer.IP = newIP.AsSlice()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) err = transaction.SavePeer(ctx, accountID, peer)
if err != nil { if err != nil {
return fmt.Errorf("save peer: %w", err) return fmt.Errorf("save peer: %w", err)
} }

View File

@@ -783,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return return
} }
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid") assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -900,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
} }
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
} }
@@ -1786,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings) assert.NotNil(t, settings)
@@ -1971,7 +1971,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2655,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
}) })
@@ -2669,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty") assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
@@ -2683,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { t.Run("jwt match existing api group in user auto groups", func(t *testing.T) {
account.Users["user1"].AutoGroups = []string{"group1"} account.Users["user1"].AutoGroups = []string{"group1"}
assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"]))
claims := nbcontext.UserAuth{ claims := nbcontext.UserAuth{
UserId: "user1", UserId: "user1",
@@ -2704,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2722,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2736,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2750,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
@@ -2768,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2783,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
@@ -3348,11 +3348,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) err = manager.Store.AddPeerToAccount(ctx, peer1)
require.NoError(t, err) require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) err = manager.Store.AddPeerToAccount(ctx, peer2)
require.NoError(t, err) require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) { t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
@@ -3364,20 +3364,20 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1)) require.NoError(t, manager.Store.CreateGroup(ctx, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID) user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1") assert.Contains(t, group.Peers, "peer1")
@@ -3386,13 +3386,13 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
t.Run("should update membership and account peers for used groups", func(t *testing.T) { t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2)) require.NoError(t, manager.Store.CreateGroup(ctx, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID) user.AutoGroups = append(user.AutoGroups, group2.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{ _, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
Name: "Group1 Policy", Name: "Group1 Policy",
@@ -3415,7 +3415,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers) assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
@@ -3432,18 +3432,18 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = []string{"group1"} user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) require.NoError(t, manager.Store.SaveUser(ctx, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err) require.NoError(t, err)
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil return userAuth, nil
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return userAuth, err return userAuth, err
} }
@@ -94,7 +94,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
// MarkPATUsed marks a personal access token as used // MarkPATUsed marks a personal access token as used
func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error { func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) return am.store.MarkPATUsed(ctx, tokenID)
} }
// GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. // GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token.
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err return nil, nil, "", "", err
} }
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil { if err != nil {
return nil, nil, "", "", err return nil, nil, "", "", err
} }
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil { if err != nil {
return err return err
} }
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err return err
}) })
if err != nil { if err != nil {

View File

@@ -8,14 +8,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
) )
// DNSConfigCache is a thread-safe cache for DNS configuration components // DNSConfigCache is a thread-safe cache for DNS configuration components
@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil return nil
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -134,7 +134,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return return

View File

@@ -11,9 +11,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func isEnabled() bool { func isEnabled() bool {
@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
} }
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) { func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue continue
} }
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil { if err != nil {
// @todo consider logging // @todo consider logging
continue continue

View File

@@ -14,11 +14,11 @@ import (
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
type GroupLinkError struct { type GroupLinkError struct {
@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// CreateGroup object of the peers // CreateGroup object of the peers
@@ -96,11 +96,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil { if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err) return status.Errorf(status.Internal, "failed to create group: %v", err)
} }
@@ -147,7 +147,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil { if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
} }
@@ -176,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup) return transaction.UpdateGroup(ctx, newGroup)
}) })
if err != nil { if err != nil {
return err return err
@@ -234,11 +234,11 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) return transaction.CreateGroups(ctx, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -292,11 +292,11 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) return transaction.UpdateGroups(ctx, accountID, groupsToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -320,7 +320,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil { if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -332,13 +332,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
} }
modifiedPeers := slices.Concat(addedPeers, removedPeers) modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil return nil
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil return nil
@@ -423,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
deletedGroups = append(deletedGroups, group) deletedGroups = append(deletedGroups, group)
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete)
}) })
if err != nil { if err != nil {
return err return err
@@ -454,7 +454,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -495,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) return transaction.UpdateGroup(ctx, group)
}) })
if err != nil { if err != nil {
return err return err
@@ -526,7 +526,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -567,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) return transaction.UpdateGroup(ctx, group)
}) })
if err != nil { if err != nil {
return err return err
@@ -591,7 +591,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil { if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err return err
@@ -608,7 +608,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
@@ -620,7 +620,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get user") return status.Errorf(status.Internal, "failed to get user")
} }
@@ -666,7 +666,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings") return status.Errorf(status.Internal, "failed to get DNS settings")
} }
@@ -675,7 +675,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get account settings") return status.Errorf(status.Internal, "failed to get account settings")
} }
@@ -689,7 +689,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil return false, nil
@@ -709,7 +709,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil return false, nil
@@ -727,7 +727,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil return false, nil
@@ -746,7 +746,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil return false, nil
@@ -762,7 +762,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil return false, nil
@@ -778,7 +778,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. // isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil return false, nil
@@ -798,7 +798,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -826,7 +826,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -26,10 +26,10 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
peer2 "github.com/netbirdio/netbird/management/server/peer" peer2 "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -898,7 +898,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) {
} }
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = transaction.AddPeerToAccount(context.Background(), peer)
if err != nil { if err != nil {
return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) return fmt.Errorf("AddPeer failed for peer %d: %w", i, err)
} }
@@ -971,7 +971,7 @@ func Test_IncrementNetworkSerial(t *testing.T) {
<-start <-start
err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error {
err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID) err = transaction.IncrementNetworkSerial(context.Background(), accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account %s: %v", accountID, err) return fmt.Errorf("failed to get account %s: %v", accountID, err)
} }

View File

@@ -6,12 +6,12 @@ import (
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/http/api"
) )
type Manager interface { type Manager interface {
@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err return nil, err
} }
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err) return nil, fmt.Errorf("error adding resource to group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err) return nil, fmt.Errorf("error removing resource from group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/auth" "github.com/netbirdio/netbird/management/server/auth"
@@ -32,9 +31,10 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
internalStatus "github.com/netbirdio/netbird/shared/management/status"
) )
// GRPCServer an instance of a Management gRPC API server // GRPCServer an instance of a Management gRPC API server
@@ -920,7 +920,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, err return nil, err
} }
peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String()) peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String())
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String()) log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String())
// TODO: consider idempotency // TODO: consider idempotency

View File

@@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
@@ -97,17 +97,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
var settings *types.Settings var settings *types.Settings
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -22,7 +22,6 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
"github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/formatter/hook"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
@@ -30,6 +29,7 @@ import (
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -645,7 +645,7 @@ func testSyncStatusRace(t *testing.T) {
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -13,9 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$`
@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@@ -73,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) return transaction.SaveNameServerGroup(ctx, newNSGroup)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -127,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) return transaction.SaveNameServerGroup(ctx, nsGroupToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -173,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID)
}) })
if err != nil { if err != nil {
return err return err
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
} }
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -14,8 +14,8 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -73,7 +73,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network
unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
defer unlock() defer unlock()
err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) err = m.store.SaveNetwork(ctx, network)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to save network: %w", err) return nil, fmt.Errorf("failed to save network: %w", err)
} }
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -114,7 +114,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) return network, m.store.SaveNetwork(ctx, network)
} }
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
@@ -162,12 +162,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event) eventsToStore = append(eventsToStore, event)
} }
err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) err = transaction.DeleteNetwork(ctx, accountID, networkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete network: %w", err) return fmt.Errorf("failed to delete network: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }

View File

@@ -12,10 +12,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
nbtypes "github.com/netbirdio/netbird/management/server/types" nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err) return nil, fmt.Errorf("failed to get network resources: %w", err)
} }
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func() var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil { if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
} }
@@ -123,7 +123,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) err = transaction.SaveNetworkResource(ctx, resource)
if err != nil { if err != nil {
return fmt.Errorf("failed to save network resource: %w", err) return fmt.Errorf("failed to save network resource: %w", err)
} }
@@ -145,7 +145,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
eventsToStore = append(eventsToStore, event) eventsToStore = append(eventsToStore, event)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err) return nil, fmt.Errorf("failed to get network resource: %w", err)
} }
@@ -218,22 +218,22 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
} }
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID { if err == nil && oldResource.ID != resource.ID {
return status.Errorf(status.InvalidArgument, "new resource name already exists") return status.Errorf(status.InvalidArgument, "new resource name already exists")
} }
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) err = transaction.SaveNetworkResource(ctx, resource)
if err != nil { if err != nil {
return fmt.Errorf("failed to save network resource: %w", err) return fmt.Errorf("failed to save network resource: %w", err)
} }
@@ -248,7 +248,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network)) m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
}) })
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -325,7 +325,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
return fmt.Errorf("failed to delete resource: %w", err) return fmt.Errorf("failed to delete resource: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -375,7 +375,7 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
eventsToStore = append(eventsToStore, event) eventsToStore = append(eventsToStore, event)
} }
err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) err = transaction.DeleteNetworkResource(ctx, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete network resource: %w", err) return nil, fmt.Errorf("failed to delete network resource: %w", err)
} }

View File

@@ -14,8 +14,8 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network routers: %w", err) return nil, fmt.Errorf("failed to get network routers: %w", err)
} }
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -104,12 +104,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
router.ID = xid.New().String() router.ID = xid.New().String()
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) err = transaction.SaveNetworkRouter(ctx, router)
if err != nil { if err != nil {
return fmt.Errorf("failed to create network router: %w", err) return fmt.Errorf("failed to create network router: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err) return nil, fmt.Errorf("failed to get network router: %w", err)
} }
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -171,12 +171,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
} }
err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) err = transaction.SaveNetworkRouter(ctx, router)
if err != nil { if err != nil {
return fmt.Errorf("failed to update network router: %w", err) return fmt.Errorf("failed to update network router: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -213,7 +213,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
return fmt.Errorf("failed to delete network router: %w", err) return fmt.Errorf("failed to delete network router: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
} }
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err) return nil, fmt.Errorf("failed to get network: %w", err)
} }
@@ -246,7 +246,7 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
} }
err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) err = transaction.DeleteNetworkRouter(ctx, accountID, routerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete network router: %w", err) return nil, fmt.Errorf("failed to delete network router: %w", err)
} }

View File

@@ -17,28 +17,28 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/shared/management/status"
) )
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
} }
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter) accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return accountPeers, nil return accountPeers, nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get account settings: %w", err) return nil, fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -130,7 +130,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -173,7 +173,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
peer.Location.CountryCode = location.Country.ISOCode peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID peer.Location.GeoNameID = location.City.GeonameID
err = transaction.SavePeerLocation(ctx, store.LockingStrengthUpdate, accountID, peer) err = transaction.SavePeerLocation(ctx, accountID, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
} }
@@ -182,7 +182,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio
log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *newStatus) err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -219,7 +219,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err return err
} }
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -281,7 +281,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
inactivityExpirationChanged = true inactivityExpirationChanged = true
} }
return transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) return transaction.SavePeer(ctx, accountID, peer)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -346,7 +346,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil { if err != nil {
return err return err
} }
@@ -383,7 +383,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return fmt.Errorf("failed to delete peer: %w", err) return fmt.Errorf("failed to delete peer: %w", err)
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}() }()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) err = transaction.AddPeerToAccount(ctx, newPeer)
if err != nil { if err != nil {
return err return err
} }
@@ -658,7 +658,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) err = transaction.IncrementNetworkSerial(ctx, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@@ -734,7 +734,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var err error var err error
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -746,7 +746,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return err return err
} }
@@ -774,7 +774,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate() am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err return err
} }
@@ -849,7 +849,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
var isPeerUpdated bool var isPeerUpdated bool
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -911,7 +911,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
} }
if shouldStorePeer { if shouldStorePeer {
if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err return err
} }
} }
@@ -934,7 +934,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
// getPeerPostureChecks returns the posture checks for the peer. // getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -958,7 +958,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
} }
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -973,7 +973,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
continue continue
} }
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -998,7 +998,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests. // and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error { func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
if err != nil { if err != nil {
return err return err
} }
@@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1080,7 +1080,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
// If peer was expired before and if it reached this point, it is re-authenticated. // If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer. // UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin() peer = peer.UpdateLastLogin()
err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, peer.AccountID, peer) err = transaction.SavePeer(ctx, peer.AccountID, peer)
if err != nil { if err != nil {
return err return err
} }
@@ -1090,7 +1090,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err) log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -1132,7 +1132,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return peer, nil return peer, nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1171,7 +1171,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1394,7 +1394,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected. // This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1404,7 +1404,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1438,7 +1438,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected. // This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1448,7 +1448,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1479,12 +1479,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
// getExpiredPeers returns peers that have been expired. // getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1502,12 +1502,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
// getInactivePeers returns peers that have been expired by inactivity // getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1530,7 +1530,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
// getPeerGroupIDs returns the IDs of the groups that the peer is part of. // getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID) return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
} }
// IsPeerInActiveGroup checks if the given peer is part of a group that is used // IsPeerInActiveGroup checks if the given peer is part of a group that is used
@@ -1548,7 +1548,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func() var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1568,7 +1568,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err return nil, err
} }
if err = transaction.DeletePeer(ctx, store.LockingStrengthUpdate, accountID, peer.ID); err != nil { if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil {
return nil, err return nil, err
} }
@@ -1624,7 +1624,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account. // isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) { func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil return false, nil

View File

@@ -38,8 +38,6 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@@ -47,6 +45,8 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
) )
func TestPeer_LoginExpired(t *testing.T) { func TestPeer_LoginExpired(t *testing.T) {
@@ -1307,7 +1307,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID) assert.Equal(t, peer.UserID, existingUserID)
@@ -1442,7 +1442,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success") assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch") assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key) require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store") assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
@@ -1528,7 +1528,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err) require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
require.Error(t, err) require.Error(t, err)
account, err := s.GetAccount(context.Background(), existingAccountID) account, err := s.GetAccount(context.Background(), existingAccountID)
@@ -1699,7 +1699,7 @@ func Test_LoginPeer(t *testing.T) {
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer") assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey) require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore") assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")
@@ -2160,10 +2160,10 @@ func Test_IsUniqueConstraintError(t *testing.T) {
} }
t.Cleanup(cleanup) t.Cleanup(cleanup)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = s.AddPeerToAccount(context.Background(), peer)
assert.NoError(t, err) assert.NoError(t, err)
err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) err = s.AddPeerToAccount(context.Background(), peer)
result := isUniqueConstraintError(err) result := isUniqueConstraintError(err)
assert.True(t, result) assert.True(t, result)
}) })

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
} }
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
} }
if !allowed { if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
} }
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
} }
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
} }

View File

@@ -11,9 +11,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
return true, nil return true, nil
} }
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -6,11 +6,11 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
@@ -61,7 +61,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
saveFunc = transaction.SavePolicy saveFunc = transaction.SavePolicy
} }
return saveFunc(ctx, store.LockingStrengthUpdate, policy) return saveFunc(ctx, policy)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -113,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID) return transaction.DeletePolicy(ctx, accountID, policyID)
}) })
if err != nil { if err != nil {
return err return err
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
} }
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate { if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" { if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.AccountID = accountID policy.AccountID = accountID
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil { if err != nil {
return err return err
} }
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -13,9 +13,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
) )
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
} }
// SavePostureChecks saves a posture check. // SavePostureChecks saves a posture check.
@@ -62,7 +62,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
@@ -70,7 +70,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
} }
postureChecks.AccountID = accountID postureChecks.AccountID = accountID
return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) return transaction.SavePostureChecks(ctx, postureChecks)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
if err != nil { if err != nil {
return err return err
} }
@@ -110,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) return transaction.DeletePostureChecks(ctx, accountID, postureChecksID)
}) })
if err != nil { if err != nil {
return err return err
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
} }
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
// If the posture check already has an ID, verify its existence in the store. // If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" { if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
return err return err
} }
return nil return nil
} }
// For new posture checks, ensure no duplicates by name. // For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -9,15 +9,15 @@ import (
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
) )
// GetRoute gets a route object from account and route IDs // GetRoute gets a route object from account and route IDs
@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
if err != nil { if err != nil {
return err return err
} }
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
if peerID := checkRoute.Peer; peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
if err != nil { if err != nil {
return err return err
} }
@@ -181,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute) return transaction.SaveRoute(ctx, newRoute)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -238,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
} }
routeToSave.AccountID = accountID routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave) return transaction.SaveRoute(ctx, routeToSave)
}) })
if err != nil { if err != nil {
return err return err
@@ -284,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return err return err
} }
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
return err return err
} }
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) return transaction.DeleteRoute(ctx, accountID, string(routeID))
}) })
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
} }
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
// validateRouteGroups validates the route groups and returns the validated groups map. // validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -27,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
const ( const (
@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *types.Group var groupHA1, groupHA2 *types.Group
for _, group := range groups { for _, group := range groups {

View File

@@ -11,10 +11,10 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/management/status"
) )
type Manager interface { type Manager interface {
@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }
@@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }

View File

@@ -10,10 +10,10 @@ import (
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -81,7 +81,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) return transaction.SaveSetupKey(ctx, setupKey)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
} }
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyToSave.Id)
if err != nil { if err != nil {
return err return err
} }
@@ -148,7 +148,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey)
eventsToStore = append(eventsToStore, events...) eventsToStore = append(eventsToStore, events...)
return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) return transaction.SaveSetupKey(ctx, newKey)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -214,12 +214,12 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
var deletedSetupKey *types.SetupKey var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyID)
if err != nil { if err != nil {
return err return err
} }
return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) return transaction.DeleteSetupKey(ctx, accountID, keyID)
}) })
if err != nil { if err != nil {
return err return err
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
} }
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil return nil

View File

@@ -29,11 +29,11 @@ import (
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
) )
const ( const (
@@ -270,7 +270,7 @@ func generateAccountSQLTypes(account *types.Account) {
func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) {
var acc types.Account var acc types.Account
var domain string var domain string
result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).Take(&domain)
if result.Error != nil { if result.Error != nil {
if !errors.Is(result.Error, gorm.ErrRecordNotFound) { if !errors.Is(result.Error, gorm.ErrRecordNotFound) {
log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error) log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error)
@@ -323,23 +323,26 @@ func (s *SqlStore) SaveInstallationID(_ context.Context, ID string) error {
func (s *SqlStore) GetInstallationID() string { func (s *SqlStore) GetInstallationID() string {
var installation installation var installation installation
if result := s.db.First(&installation, idQueryCondition, s.installationPK); result.Error != nil { if result := s.db.Take(&installation, idQueryCondition, s.installationPK); result.Error != nil {
return "" return ""
} }
return installation.InstallationIDValue return installation.InstallationIDValue
} }
func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error { func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy() peerCopy := peer.Copy()
peerCopy.AccountID = accountID peerCopy.AccountID = accountID
err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
// check if peer exists before saving // check if peer exists before saving
var peerID string var peerID string
result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID)
}
return result.Error return result.Error
} }
@@ -385,7 +388,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
return nil return nil
} }
func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus peerCopy.Status = &peerStatus
@@ -393,7 +396,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren
"peer_status_last_seen", "peer_status_connected", "peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval", "peer_status_login_expired", "peer_status_required_approval",
} }
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate). Select(fieldsToUpdate).
Where(accountAndIDQueryCondition, accountID, peerID). Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy) Updates(&peerCopy)
@@ -408,14 +411,14 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren
return nil return nil
} }
func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error { func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization, // Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database. // updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location peerCopy.Location = peerWithLocation.Location
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). result := s.db.Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy) Updates(peerCopy)
@@ -431,12 +434,12 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
} }
// SaveUsers saves the given list of users to the database. // SaveUsers saves the given list of users to the database.
func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error { func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error {
if len(users) == 0 { if len(users) == 0 {
return nil return nil
} }
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&users) result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error) log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save users to store") return status.Errorf(status.Internal, "failed to save users to store")
@@ -445,8 +448,8 @@ func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength,
} }
// SaveUser saves the given user to the database. // SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) result := s.db.Save(user)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error) log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save user to store") return status.Errorf(status.Internal, "failed to save user to store")
@@ -455,7 +458,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
} }
// CreateGroups creates the given list of groups to the database. // CreateGroups creates the given list of groups to the database.
func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
if len(groups) == 0 { if len(groups) == 0 {
return nil return nil
} }
@@ -463,7 +466,6 @@ func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrengt
return s.db.Transaction(func(tx *gorm.DB) error { return s.db.Transaction(func(tx *gorm.DB) error {
result := tx. result := tx.
Clauses( Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{ clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true, UpdateAll: true,
@@ -481,7 +483,7 @@ func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrengt
} }
// UpdateGroups updates the given list of groups to the database. // UpdateGroups updates the given list of groups to the database.
func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error {
if len(groups) == 0 { if len(groups) == 0 {
return nil return nil
} }
@@ -489,7 +491,6 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrengt
return s.db.Transaction(func(tx *gorm.DB) error { return s.db.Transaction(func(tx *gorm.DB) error {
result := tx. result := tx.
Clauses( Clauses(
clause.Locking{Strength: string(lockStrength)},
clause.OnConflict{ clause.OnConflict{
Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}},
UpdateAll: true, UpdateAll: true,
@@ -517,7 +518,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -536,7 +537,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
result := tx.Model(&types.Account{}).Select("id"). result := tx.Model(&types.Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
strings.ToLower(domain), true, types.PrivateCategory, strings.ToLower(domain), true, types.PrivateCategory,
).First(&accountID) ).Take(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
@@ -550,7 +551,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) {
var key types.SetupKey var key types.SetupKey
result := s.db.Select("account_id").First(&key, GetKeyQueryCondition(s), setupKey) result := s.db.Select("account_id").Take(&key, GetKeyQueryCondition(s), setupKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKey) return nil, status.NewSetupKeyNotFoundError(setupKey)
@@ -568,7 +569,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
var token types.PersonalAccessToken var token types.PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken) result := s.db.Take(&token, "hashed_token = ?", hashedToken)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -589,7 +590,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
var user types.User var user types.User
result := tx. result := tx.
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
Where("personal_access_tokens.id = ?", patID).First(&user) Where("personal_access_tokens.id = ?", patID).Take(&user)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(patID) return nil, status.NewPATNotFoundError(patID)
@@ -608,7 +609,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
} }
var user types.User var user types.User
result := tx.First(&user, idQueryCondition, userID) result := tx.Take(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID) return nil, status.NewUserNotFoundError(userID)
@@ -619,16 +620,14 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil return &user, nil
} }
func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error { func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error {
err := s.db.Transaction(func(tx *gorm.DB) error { err := s.db.Transaction(func(tx *gorm.DB) error {
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
Delete(&types.PersonalAccessToken{}, "user_id = ?", userID)
if result.Error != nil { if result.Error != nil {
return result.Error return result.Error
} }
return tx.Clauses(clause.Locking{Strength: string(lockStrength)}). return tx.Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error
}) })
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err) log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err)
@@ -664,7 +663,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
} }
var user types.User var user types.User
result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed")
@@ -761,7 +760,7 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren
var accountMeta types.AccountMeta var accountMeta types.AccountMeta
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
First(&accountMeta, idQueryCondition, accountID) Take(&accountMeta, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error) log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -776,7 +775,7 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren
// GetAccountOnboarding retrieves the onboarding information for a specific account. // GetAccountOnboarding retrieves the onboarding information for a specific account.
func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) { func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) {
var accountOnboarding types.AccountOnboarding var accountOnboarding types.AccountOnboarding
result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) result := s.db.Model(&accountOnboarding).Take(&accountOnboarding, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountOnboardingNotFoundError(accountID) return nil, status.NewAccountOnboardingNotFoundError(accountID)
@@ -813,7 +812,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
Omit("GroupsG"). Omit("GroupsG").
Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload("UsersG.PATsG"). // have to be specifies as this is nester reference
Preload(clause.Associations). Preload(clause.Associations).
First(&account, idQueryCondition, accountID) Take(&account, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -888,7 +887,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) {
var user types.User var user types.User
result := s.db.Select("account_id").First(&user, idQueryCondition, userID) result := s.db.Select("account_id").Take(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -905,7 +904,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) result := s.db.Select("account_id").Take(&peer, idQueryCondition, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -922,7 +921,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*type
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
result := s.db.Select("account_id").First(&peer, GetKeyQueryCondition(s), peerKey) result := s.db.Select("account_id").Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -954,7 +953,7 @@ func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) {
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
var peer nbpeer.Peer var peer nbpeer.Peer
var accountID string var accountID string
result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).First(&accountID) result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).Take(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -973,7 +972,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
var accountID string var accountID string
result := tx.Model(&types.User{}). result := tx.Model(&types.User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID) Select("account_id").Where(idQueryCondition, userID).Take(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@@ -992,7 +991,7 @@ func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength Lockin
var accountID string var accountID string
result := tx.Model(&nbpeer.Peer{}). result := tx.Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID) Select("account_id").Where(idQueryCondition, peerID).Take(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "peer %s account not found", peerID) return "", status.Errorf(status.NotFound, "peer %s account not found", peerID)
@@ -1005,7 +1004,7 @@ func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength Lockin
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var accountID string var accountID string
result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID) result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).Take(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewSetupKeyNotFoundError(setupKey) return "", status.NewSetupKeyNotFoundError(setupKey)
@@ -1082,7 +1081,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
} }
var accountNetwork types.AccountNetwork var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
} }
@@ -1098,7 +1097,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
} }
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx.First(&peer, GetKeyQueryCondition(s), peerKey) result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1117,7 +1116,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
} }
var accountSettings types.AccountSettings var accountSettings types.AccountSettings
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found") return nil, status.Errorf(status.NotFound, "settings not found")
} }
@@ -1134,7 +1133,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
var createdBy string var createdBy string
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID) Select("created_by").Take(&createdBy, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError(accountID) return "", status.NewAccountNotFoundError(accountID)
@@ -1148,7 +1147,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user types.User var user types.User
result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID) return status.NewUserNotFoundError(userID)
@@ -1171,7 +1170,7 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
} }
var postureCheck posture.Checks var postureCheck posture.Checks
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).Take(&postureCheck).Error
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1336,7 +1335,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
var setupKey types.SetupKey var setupKey types.SetupKey
result := tx. result := tx.
First(&setupKey, GetKeyQueryCondition(s), key) Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1446,7 +1445,7 @@ func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) e
// AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction // AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction
func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error { func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error {
var group types.Group var group types.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).Take(&group)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID) return status.NewGroupNotFoundError(groupID)
@@ -1473,7 +1472,7 @@ func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, gro
// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction // RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction
func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error { func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error {
var group types.Group var group types.Group
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).Take(&group)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewGroupNotFoundError(groupID) return status.NewGroupNotFoundError(groupID)
@@ -1593,7 +1592,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
return peers, nil return peers, nil
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.Create(peer).Error; err != nil { if err := s.db.Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err) return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
} }
@@ -1610,7 +1609,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
var peer *nbpeer.Peer var peer *nbpeer.Peer
result := tx. result := tx.
First(&peer, accountAndIDQueryCondition, accountID, peerID) Take(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPeerNotFoundError(peerID) return nil, status.NewPeerNotFoundError(peerID)
@@ -1705,9 +1704,8 @@ func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength Lockin
} }
// DeletePeer removes a peer from the store. // DeletePeer removes a peer from the store.
func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete peer from store") return status.Errorf(status.Internal, "failed to delete peer from store")
@@ -1720,9 +1718,8 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength,
return nil return nil
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store") return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -1772,7 +1769,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
var accountDNSSettings types.AccountDNSSettings var accountDNSSettings types.AccountDNSSettings
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID) Take(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError(accountID)
@@ -1792,7 +1789,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
var accountID string var accountID string
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
Select("id").First(&accountID, idQueryCondition, id) Select("id").Take(&accountID, idQueryCondition, id)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil return false, nil
@@ -1812,7 +1809,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
var account types.Account var account types.Account
result := tx.Model(&types.Account{}).Select("domain", "domain_category"). result := tx.Model(&types.Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account) Where(idQueryCondition, accountID).Take(&account)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found") return "", "", status.Errorf(status.NotFound, "account not found")
@@ -1831,7 +1828,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
} }
var group *types.Group var group *types.Group
result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID) result := tx.Preload(clause.Associations).Take(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewGroupNotFoundError(groupID) return nil, status.NewGroupNotFoundError(groupID)
@@ -1900,7 +1897,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
} }
// CreateGroup creates a group in the store. // CreateGroup creates a group in the store.
func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { func (s *SqlStore) CreateGroup(ctx context.Context, group *types.Group) error {
if group == nil { if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil") return status.Errorf(status.InvalidArgument, "group is nil")
} }
@@ -1914,7 +1911,7 @@ func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength
} }
// UpdateGroup updates a group in the store. // UpdateGroup updates a group in the store.
func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error {
if group == nil { if group == nil {
return status.Errorf(status.InvalidArgument, "group is nil") return status.Errorf(status.InvalidArgument, "group is nil")
} }
@@ -1928,9 +1925,8 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength
} }
// DeleteGroup deletes a group from the database. // DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { func (s *SqlStore) DeleteGroup(ctx context.Context, accountID, groupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Select(clause.Associations).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
@@ -1945,9 +1941,8 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength
} }
// DeleteGroups deletes groups from the database. // DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { func (s *SqlStore) DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error {
result := s.db.Clauses(clause.Locking{Strength: string(strength)}). result := s.db.Select(clause.Associations).
Select(clause.Associations).
Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
@@ -1985,7 +1980,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
var policy *types.Policy var policy *types.Policy
result := tx.Preload(clause.Associations). result := tx.Preload(clause.Associations).
First(&policy, accountAndIDQueryCondition, accountID, policyID) Take(&policy, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewPolicyNotFoundError(policyID) return nil, status.NewPolicyNotFoundError(policyID)
@@ -1997,8 +1992,8 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
return policy, nil return policy, nil
} }
func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) result := s.db.Create(policy)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
return status.Errorf(status.Internal, "failed to create policy in store") return status.Errorf(status.Internal, "failed to create policy in store")
@@ -2008,9 +2003,8 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt
} }
// SavePolicy saves a policy to the database. // SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error {
result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy)
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store") return status.Errorf(status.Internal, "failed to save policy to store")
@@ -2018,13 +2012,13 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength,
return nil return nil
} }
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error {
return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil {
return fmt.Errorf("delete policy rules: %w", err) return fmt.Errorf("delete policy rules: %w", err)
} }
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.
Where(accountAndIDQueryCondition, accountID, policyID). Where(accountAndIDQueryCondition, accountID, policyID).
Delete(&types.Policy{}) Delete(&types.Policy{})
@@ -2067,7 +2061,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
var postureCheck *posture.Checks var postureCheck *posture.Checks
result := tx. result := tx.
First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) Take(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPostureChecksNotFoundError(postureChecksID) return nil, status.NewPostureChecksNotFoundError(postureChecksID)
@@ -2102,8 +2096,8 @@ func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength Locki
} }
// SavePostureChecks saves a posture checks to the database. // SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { func (s *SqlStore) SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) result := s.db.Save(postureCheck)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save posture checks to store") return status.Errorf(status.Internal, "failed to save posture checks to store")
@@ -2113,9 +2107,8 @@ func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingSt
} }
// DeletePostureChecks deletes a posture checks from the database. // DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { func (s *SqlStore) DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete posture checks from store") return status.Errorf(status.Internal, "failed to delete posture checks from store")
@@ -2130,9 +2123,13 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var routes []*route.Route var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.Find(&routes, accountIDCondition, accountID)
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store") return nil, status.Errorf(status.Internal, "failed to get routes from store")
@@ -2143,9 +2140,13 @@ func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStr
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var route *route.Route var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.Take(&route, accountAndIDQueryCondition, accountID, routeID)
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID) return nil, status.NewRouteNotFoundError(routeID)
@@ -2158,8 +2159,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
} }
// SaveRoute saves a route to the database. // SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { func (s *SqlStore) SaveRoute(ctx context.Context, route *route.Route) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) result := s.db.Save(route)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store") return status.Errorf(status.Internal, "failed to save route to store")
@@ -2169,9 +2170,8 @@ func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength,
} }
// DeleteRoute deletes a route from the database. // DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { func (s *SqlStore) DeleteRoute(ctx context.Context, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store") return status.Errorf(status.Internal, "failed to delete route from store")
@@ -2210,8 +2210,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
} }
var setupKey *types.SetupKey var setupKey *types.SetupKey
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.Take(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID) return nil, status.NewSetupKeyNotFoundError(setupKeyID)
@@ -2224,8 +2223,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
} }
// SaveSetupKey saves a setup key to the database. // SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error { func (s *SqlStore) SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) result := s.db.Save(setupKey)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error)
return status.Errorf(status.Internal, "failed to save setup key to store") return status.Errorf(status.Internal, "failed to save setup key to store")
@@ -2235,8 +2234,8 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt
} }
// DeleteSetupKey deletes a setup key from the database. // DeleteSetupKey deletes a setup key from the database.
func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) result := s.db.Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete setup key from store") return status.Errorf(status.Internal, "failed to delete setup key from store")
@@ -2275,7 +2274,7 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
var nsGroup *nbdns.NameServerGroup var nsGroup *nbdns.NameServerGroup
result := tx. result := tx.
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) Take(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewNameServerGroupNotFoundError(nsGroupID) return nil, status.NewNameServerGroupNotFoundError(nsGroupID)
@@ -2288,8 +2287,8 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
} }
// SaveNameServerGroup saves a name server group to the database. // SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { func (s *SqlStore) SaveNameServerGroup(ctx context.Context, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) result := s.db.Save(nameServerGroup)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store") return status.Errorf(status.Internal, "failed to save name server group to store")
@@ -2298,8 +2297,8 @@ func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength Locking
} }
// DeleteNameServerGroup deletes a name server group from the database. // DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error { func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) result := s.db.Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store") return status.Errorf(status.Internal, "failed to delete name server group from store")
@@ -2313,8 +2312,8 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
} }
// SaveDNSSettings saves the DNS settings to the store. // SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { func (s *SqlStore) SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). result := s.db.Model(&types.Account{}).
Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings}) Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
@@ -2329,8 +2328,8 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre
} }
// SaveAccountSettings stores the account settings in DB. // SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error { func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). result := s.db.Model(&types.Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings}) Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings})
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error)
@@ -2367,8 +2366,7 @@ func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStren
} }
var network *networkTypes.Network var network *networkTypes.Network
result := tx. result := tx.Take(&network, accountAndIDQueryCondition, accountID, networkID)
First(&network, accountAndIDQueryCondition, accountID, networkID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewNetworkNotFoundError(networkID) return nil, status.NewNetworkNotFoundError(networkID)
@@ -2381,8 +2379,8 @@ func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStren
return network, nil return network, nil
} }
func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error { func (s *SqlStore) SaveNetwork(ctx context.Context, network *networkTypes.Network) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) result := s.db.Save(network)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save network to store") return status.Errorf(status.Internal, "failed to save network to store")
@@ -2391,9 +2389,8 @@ func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength
return nil return nil
} }
func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { func (s *SqlStore) DeleteNetwork(ctx context.Context, accountID, networkID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID)
Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete network from store") return status.Errorf(status.Internal, "failed to delete network from store")
@@ -2448,7 +2445,7 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
var netRouter *routerTypes.NetworkRouter var netRouter *routerTypes.NetworkRouter
result := tx. result := tx.
First(&netRouter, accountAndIDQueryCondition, accountID, routerID) Take(&netRouter, accountAndIDQueryCondition, accountID, routerID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewNetworkRouterNotFoundError(routerID) return nil, status.NewNetworkRouterNotFoundError(routerID)
@@ -2460,8 +2457,8 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin
return netRouter, nil return netRouter, nil
} }
func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error { func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router) result := s.db.Save(router)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save network router to store") return status.Errorf(status.Internal, "failed to save network router to store")
@@ -2470,9 +2467,8 @@ func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingSt
return nil return nil
} }
func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete network router from store") return status.Errorf(status.Internal, "failed to delete network router from store")
@@ -2527,7 +2523,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
var netResources *resourceTypes.NetworkResource var netResources *resourceTypes.NetworkResource
result := tx. result := tx.
First(&netResources, accountAndIDQueryCondition, accountID, resourceID) Take(&netResources, accountAndIDQueryCondition, accountID, resourceID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewNetworkResourceNotFoundError(resourceID) return nil, status.NewNetworkResourceNotFoundError(resourceID)
@@ -2547,7 +2543,7 @@ func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength Lo
var netResources *resourceTypes.NetworkResource var netResources *resourceTypes.NetworkResource
result := tx. result := tx.
First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) Take(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewNetworkResourceNotFoundError(resourceName) return nil, status.NewNetworkResourceNotFoundError(resourceName)
@@ -2559,8 +2555,8 @@ func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength Lo
return netResources, nil return netResources, nil
} }
func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error { func (s *SqlStore) SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource) result := s.db.Save(resource)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error) log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save network resource to store") return status.Errorf(status.Internal, "failed to save network resource to store")
@@ -2569,9 +2565,8 @@ func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength Locking
return nil return nil
} }
func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { func (s *SqlStore) DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error) log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete network resource from store") return status.Errorf(status.Internal, "failed to delete network resource from store")
@@ -2592,7 +2587,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
} }
var pat types.PersonalAccessToken var pat types.PersonalAccessToken
result := tx.First(&pat, "hashed_token = ?", hashedToken) result := tx.Take(&pat, "hashed_token = ?", hashedToken)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(hashedToken) return nil, status.NewPATNotFoundError(hashedToken)
@@ -2613,7 +2608,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
var pat types.PersonalAccessToken var pat types.PersonalAccessToken
result := tx. result := tx.
First(&pat, "id = ? AND user_id = ?", patID, userID) Take(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewPATNotFoundError(patID) return nil, status.NewPATNotFoundError(patID)
@@ -2643,13 +2638,13 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
} }
// MarkPATUsed marks a personal access token as used. // MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error { func (s *SqlStore) MarkPATUsed(ctx context.Context, patID string) error {
patCopy := types.PersonalAccessToken{ patCopy := types.PersonalAccessToken{
LastUsed: util.ToPtr(time.Now().UTC()), LastUsed: util.ToPtr(time.Now().UTC()),
} }
fieldsToUpdate := []string{"last_used"} fieldsToUpdate := []string{"last_used"}
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate). result := s.db.Select(fieldsToUpdate).
Where(idQueryCondition, patID).Updates(&patCopy) Where(idQueryCondition, patID).Updates(&patCopy)
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error) log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error)
@@ -2664,8 +2659,8 @@ func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength
} }
// SavePAT saves a personal access token to the database. // SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error { func (s *SqlStore) SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) result := s.db.Save(pat)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err) log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err)
return status.Errorf(status.Internal, "failed to save pat to store") return status.Errorf(status.Internal, "failed to save pat to store")
@@ -2675,9 +2670,8 @@ func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pa
} }
// DeletePAT deletes a personal access token from the database. // DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error { func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err) log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete pat from store") return status.Errorf(status.Internal, "failed to delete pat from store")
@@ -2700,7 +2694,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx. result := tx.
First(&peer, "account_id = ? AND ip = ?", accountID, jsonValue) Take(&peer, "account_id = ? AND ip = ?", accountID, jsonValue)
if result.Error != nil { if result.Error != nil {
// no logging here // no logging here
return nil, status.Errorf(status.Internal, "failed to get peer from store") return nil, status.Errorf(status.Internal, "failed to get peer from store")

File diff suppressed because it is too large Load Diff

View File

@@ -72,8 +72,8 @@ type Store interface {
SaveAccount(ctx context.Context, account *types.Account) error SaveAccount(ctx context.Context, account *types.Account) error
DeleteAccount(ctx context.Context, account *types.Account) error DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error
@@ -81,10 +81,10 @@ type Store interface {
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error)
GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error)
SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error SaveUsers(ctx context.Context, users []*types.User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUser(ctx context.Context, user *types.User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error DeleteUser(ctx context.Context, accountID, userID string) error
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteHashedPAT2TokenIDIndex(hashedToken string) error
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
@@ -92,34 +92,34 @@ type Store interface {
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error)
GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error MarkPATUsed(ctx context.Context, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error DeletePAT(ctx context.Context, userID, patID string) error
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error CreateGroup(ctx context.Context, group *types.Group) error
UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error UpdateGroup(ctx context.Context, group *types.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroup(ctx context.Context, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error)
CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error CreatePolicy(ctx context.Context, policy *types.Policy) error
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error SavePolicy(ctx context.Context, policy *types.Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error DeletePolicy(ctx context.Context, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
@@ -130,7 +130,7 @@ type Store interface {
GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error)
AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error
RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error
AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
@@ -139,30 +139,30 @@ type Store interface {
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error DeletePeer(ctx context.Context, accountID string, peerID string) error
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error SaveRoute(ctx context.Context, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error DeleteRoute(ctx context.Context, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error IncrementNetworkSerial(ctx context.Context, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error)
GetInstallationID() string GetInstallationID() string
@@ -184,21 +184,21 @@ type Store interface {
GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error)
GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error)
SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error SaveNetwork(ctx context.Context, network *networkTypes.Network) error
DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error DeleteNetwork(ctx context.Context, accountID, networkID string) error
GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error)
GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error)
SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error
DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error
GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error)
GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error)
GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error)
GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error)
SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error
DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error
GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error)
GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error)
GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error)

View File

@@ -17,11 +17,11 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
) )
// createServiceUser creates a new service user under the given account. // createServiceUser creates a new service user under the given account.
@@ -46,7 +46,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI
newUser.AccountID = accountID newUser.AccountID = accountID
log.WithContext(ctx).Debugf("New User: %v", newUser) log.WithContext(ctx).Debugf("New User: %v", newUser)
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err return nil, err
} }
@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviterID := userID inviterID := userID
if initiatorUser.IsServiceUser { if initiatorUser.IsServiceUser {
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -124,7 +124,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
} }
if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { if err = am.Store.SaveUser(ctx, newUser); err != nil {
return nil, err return nil, err
} }
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
} }
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
} }
// GetUser looks up a user by provided nbContext.UserAuths. // GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already. // Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -209,11 +209,11 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
// ListUsers returns lists of all users under the account. // ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name. // It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
} }
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
if err := am.Store.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUser.Id); err != nil { if err := am.Store.DeleteUser(ctx, accountID, targetUser.Id); err != nil {
return err return err
} }
meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt}
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -367,7 +367,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
} }
if err = am.Store.SavePAT(ctx, store.LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { if err = am.Store.SavePAT(ctx, &pat.PersonalAccessToken); err != nil {
return nil, err return nil, err
} }
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -404,12 +404,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewAdminPermissionError() return status.NewAdminPermissionError()
} }
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
if err != nil { if err != nil {
return err return err
} }
if err = am.Store.DeletePAT(ctx, store.LockingStrengthUpdate, targetUserID, tokenID); err != nil { if err = am.Store.DeletePAT(ctx, targetUserID, tokenID); err != nil {
return err return err
} }
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
} }
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var addUserEvents []func() var addUserEvents []func()
var usersToSave = make([]*types.User, 0, len(updates)) var usersToSave = make([]*types.User, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var initiatorUser *types.User var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -560,7 +560,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
updateAccountPeers = true updateAccountPeers = true
} }
} }
return transaction.SaveUsers(ctx, store.LockingStrengthUpdate, usersToSave) return transaction.SaveUsers(ctx, usersToSave)
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -593,7 +593,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
if settings.GroupsPropagationEnabled && updateAccountPeers { if settings.GroupsPropagationEnabled && updateAccountPeers {
if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil {
return nil, fmt.Errorf("failed to increment network serial: %w", err) return nil, fmt.Errorf("failed to increment network serial: %w", err)
} }
am.UpdateAccountPeers(ctx, accountID) am.UpdateAccountPeers(ctx, accountID)
@@ -700,7 +700,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists { if !addIfNotExists {
@@ -724,7 +724,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi
newInitiatorUser := initiatorUser.Copy() newInitiatorUser := initiatorUser.Copy()
newInitiatorUser.Role = types.UserRoleAdmin newInitiatorUser.Role = types.UserRoleAdmin
if err := transaction.SaveUser(ctx, store.LockingStrengthUpdate, newInitiatorUser); err != nil { if err := transaction.SaveUser(ctx, newInitiatorUser); err != nil {
return false, err return false, err
} }
return true, nil return true, nil
@@ -835,7 +835,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var user *types.User var user *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err) return nil, fmt.Errorf("failed to get user: %w", err)
} }
@@ -845,7 +845,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{} accountUsers := []*types.User{}
switch { switch {
case allowed: case allowed:
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -939,7 +939,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account // expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID) log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -956,7 +956,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true) peer.MarkLoginExpired(true)
if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { if err := am.Store.SavePeerStatus(ctx, accountID, peer.ID, *peer.Status); err != nil {
return err return err
} }
am.StoreEvent( am.StoreEvent(
@@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -1023,7 +1023,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue continue
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
allErrors = errors.Join(allErrors, err) allErrors = errors.Join(allErrors, err)
continue continue
@@ -1087,12 +1087,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var err error var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID) targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user to delete: %w", err) return fmt.Errorf("failed to get user to delete: %w", err)
} }
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user peers: %w", err) return fmt.Errorf("failed to get user peers: %w", err)
} }
@@ -1105,7 +1105,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
} }
} }
if err = transaction.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUserInfo.ID); err != nil { if err = transaction.DeleteUser(ctx, accountID, targetUserInfo.ID); err != nil {
return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err) return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err)
} }
@@ -1126,7 +1126,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
// GetOwnerInfo retrieves the owner information for a given account ID. // GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) { func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID) owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1176,7 +1176,7 @@ func validateUserInvite(invite *types.UserInfo) error {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1193,7 +1193,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,9 +15,9 @@ import (
"github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true) _, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail") assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId) user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id) assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)

View File

@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
} }
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
} }
func NewManagerMock() Manager { func NewManagerMock() Manager {