Compare commits

...

3 Commits

Author SHA1 Message Date
Maycon Santos
5f5d597c59 fix GetSetupKeyByID 2025-06-26 13:54:04 +02:00
Maycon Santos
b85aad07d4 add getTXWithLockStrength method 2025-06-26 13:44:00 +02:00
Maycon Santos
7398836c2e Stop using locking share for read calls to avoid deadlocks
Added peer.userID index
2025-06-26 12:24:42 +02:00
29 changed files with 295 additions and 493 deletions

View File

@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -708,7 +708,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists. // AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountIDByUserID retrieves the account ID based on the userID provided. // GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -720,7 +720,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided") return "", status.Errorf(status.NotFound, "no valid userID provided")
} }
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -775,7 +775,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -829,7 +829,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -859,7 +859,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err return nil, err
@@ -1010,7 +1010,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount() defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err return err
@@ -1020,7 +1020,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil return nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err) log.WithContext(ctx).Errorf("error getting user: %v", err)
return err return err
@@ -1185,7 +1185,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
@@ -1205,7 +1205,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err return "", "", err
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1237,7 +1237,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return err return err
} }
@@ -1263,12 +1263,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool var hasChanges bool
var user *types.User var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1298,7 +1298,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
groupsMap[group.ID] = group groupsMap[group.ID] = group
} }
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user peers: %w", err) return fmt.Errorf("error getting user peers: %w", err)
} }
@@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1353,7 +1353,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1414,7 +1414,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
} }
if userAuth.IsChild { if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists { if err != nil || !exists {
return "", err return "", err
} }
@@ -1438,7 +1438,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err return "", err
} }
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1459,7 +1459,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
} }
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1474,7 +1474,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests // check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
cancel() cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1485,7 +1485,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
} }
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1495,7 +1495,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
} }
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err return "", err
@@ -1506,7 +1506,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
} }
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err return "", err
@@ -1636,7 +1636,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -1658,7 +1658,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
} }
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err) return "", fmt.Errorf("failed to get peer dns labels: %w", err)
} }
@@ -1684,7 +1684,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1770,7 +1770,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel() defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
return nil, false, err return nil, false, err
} }
@@ -1790,7 +1790,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 { for range 2 {
accountId := xid.New().String() accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists { if err != nil || exists {
continue continue
} }
@@ -1865,7 +1865,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return nil return nil
} }
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error // error is not a not found error
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
@@ -1902,7 +1902,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error. // Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -1912,7 +1912,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsMap[group.ID] = group groupsMap[group.ID] = group
} }
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -1920,7 +1920,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsToUpdate := make(map[string]*types.Group) groupsToUpdate := make(map[string]*types.Group)
for _, user := range users { for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }

View File

@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return return
} }
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid") assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -899,11 +899,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
} }
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
} }
@@ -1775,7 +1775,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings) assert.NotNil(t, settings)
@@ -1960,7 +1960,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2643,7 +2643,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
}) })
@@ -2657,7 +2657,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty") assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
@@ -2671,11 +2671,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2692,11 +2692,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2710,7 +2710,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2724,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2738,11 +2738,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
@@ -2756,7 +2756,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2771,7 +2771,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
@@ -3354,7 +3354,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1)) require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID) user.AutoGroups = append(user.AutoGroups, group1.ID)
@@ -3365,7 +3365,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1") assert.Contains(t, group.Peers, "peer1")
@@ -3376,7 +3376,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2)) require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID) user.AutoGroups = append(user.AutoGroups, group2.ID)
@@ -3403,7 +3403,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers) assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
@@ -3420,7 +3420,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = []string{"group1"} user.AutoGroups = []string{"group1"}
@@ -3431,7 +3431,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil return userAuth, nil
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return userAuth, err return userAuth, err
} }
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err return nil, nil, "", "", err
} }
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil { if err != nil {
return nil, nil, "", "", err return nil, nil, "", "", err
} }
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil { if err != nil {
return err return err
} }
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err return err
}) })
if err != nil { if err != nil {

View File

@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil return nil
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -122,7 +122,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return return

View File

@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
} }
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) { func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue continue
} }
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil { if err != nil {
// @todo consider logging // @todo consider logging
continue continue

View File

@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@@ -140,7 +140,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil { if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -152,13 +152,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
} }
modifiedPeers := slices.Concat(addedPeers, removedPeers) modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil return nil
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil return nil
@@ -431,7 +431,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil { if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err return err
@@ -448,7 +448,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
@@ -460,7 +460,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get user") return status.Errorf(status.Internal, "failed to get user")
} }
@@ -506,7 +506,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings") return status.Errorf(status.Internal, "failed to get DNS settings")
} }
@@ -515,7 +515,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get account settings") return status.Errorf(status.Internal, "failed to get account settings")
} }
@@ -529,7 +529,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil return false, nil
@@ -549,7 +549,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil return false, nil
@@ -567,7 +567,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil return false, nil
@@ -586,7 +586,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil return false, nil
@@ -602,7 +602,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil return false, nil
@@ -618,7 +618,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. // isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil return false, nil
@@ -638,7 +638,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -666,7 +666,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err return nil, err
} }
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err) return nil, fmt.Errorf("error adding resource to group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err) return nil, fmt.Errorf("error removing resource from group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }

View File

@@ -63,7 +63,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
@@ -83,17 +83,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
var settings *types.Settings var settings *types.Settings
groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -641,7 +641,7 @@ func testSyncStatusRace(t *testing.T) {
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
} }
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {

View File

@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err) return nil, fmt.Errorf("failed to get network resources: %w", err)
} }
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func() var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil { if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
} }
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err) return nil, fmt.Errorf("failed to get network resource: %w", err)
} }
@@ -218,17 +218,17 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
} }
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID { if err == nil && oldResource.ID != resource.ID {
return status.Errorf(status.InvalidArgument, "new resource name already exists") return status.Errorf(status.InvalidArgument, "new resource name already exists")
} }
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }

View File

@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network routers: %w", err) return nil, fmt.Errorf("failed to get network routers: %w", err)
} }
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err) return nil, fmt.Errorf("failed to get network router: %w", err)
} }
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
} }
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err) return nil, fmt.Errorf("failed to get network: %w", err)
} }

View File

@@ -35,7 +35,7 @@ import (
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -45,7 +45,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
} }
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter) accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -55,7 +55,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return accountPeers, nil return accountPeers, nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get account settings: %w", err) return nil, fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -216,7 +216,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err return err
} }
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -335,7 +335,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil { if err != nil {
return err return err
} }
@@ -468,7 +468,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false addedByUser := false
if len(userID) > 0 { if len(userID) > 0 {
addedByUser = true addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
} else { } else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
} }
@@ -488,7 +488,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil { if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
} }
@@ -584,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ExtraDNSLabels: peer.ExtraDNSLabels, ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels,
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -674,7 +674,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) { func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID) takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err) return nil, fmt.Errorf("failed to get taken IPs: %w", err)
} }
@@ -706,7 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var err error var err error
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -718,7 +718,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return err return err
} }
@@ -821,7 +821,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
var isPeerUpdated bool var isPeerUpdated bool
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -906,7 +906,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
// getPeerPostureChecks returns the posture checks for the peer. // getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -930,7 +930,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
} }
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -945,7 +945,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
continue continue
} }
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -970,7 +970,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests. // and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error { func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
if err != nil { if err != nil {
return err return err
} }
@@ -981,7 +981,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1000,7 +1000,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}() }()
if isRequiresApproval { if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -1062,7 +1062,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err) log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -1104,7 +1104,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1117,7 +1117,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return peer, nil return peer, nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1143,7 +1143,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1328,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected. // This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1338,7 +1338,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1372,7 +1372,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected. // This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1382,7 +1382,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1413,12 +1413,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
// getExpiredPeers returns peers that have been expired. // getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1436,12 +1436,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
// getInactivePeers returns peers that have been expired by inactivity // getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1459,12 +1459,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of. // GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) return am.Store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
} }
// getPeerGroupIDs returns the IDs of the groups that the peer is part of. // getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1478,7 +1478,7 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
} }
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1505,7 +1505,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func() var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1516,7 +1516,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err return nil, err
} }
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1577,7 +1577,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account. // isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) { func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil return false, nil

View File

@@ -31,7 +31,7 @@ type Peer struct {
// Status peer's management connection status // Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer // The user ID that registered the peer
UserID string UserID string `gorm:"index"`
// SSHKey is a public SSH key of the peer // SSHKey is a public SSH key of the peer
SSHKey string SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer // SSHEnabled indicates whether SSH server is enabled on the peer

View File

@@ -1301,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID) assert.Equal(t, peer.UserID, existingUserID)
@@ -1423,7 +1423,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success") assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch") assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key) require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store") assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
@@ -1505,7 +1505,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err) require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
require.Error(t, err) require.Error(t, err)
account, err := s.GetAccount(context.Background(), existingAccountID) account, err := s.GetAccount(context.Background(), existingAccountID)
@@ -1671,7 +1671,7 @@ func Test_LoginPeer(t *testing.T) {
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer") assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey) require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore") assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")

View File

@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
} }
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
} }
if !allowed { if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
} }
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
} }
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
} }

View File

@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
return true, nil return true, nil
} }
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
} }
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate { if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" { if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.AccountID = accountID policy.AccountID = accountID
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil { if err != nil {
return err return err
} }
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
} }
// SavePostureChecks saves a posture check. // SavePostureChecks saves a posture check.
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
if err != nil { if err != nil {
return err return err
} }
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
} }
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
// If the posture check already has an ID, verify its existence in the store. // If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" { if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
return err return err
} }
return nil return nil
} }
// For new posture checks, ensure no duplicates by name. // For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
if err != nil { if err != nil {
return err return err
} }
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
if peerID := checkRoute.Peer; peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
if err != nil { if err != nil {
return err return err
} }
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
} }
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
// validateRouteGroups validates the route groups and returns the validated groups map. // validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *types.Group var groupHA1, groupHA2 *types.Group
for _, group := range groups { for _, group := range groups {

View File

@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }
@@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }

View File

@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
} }
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyToSave.Id)
if err != nil { if err != nil {
return err return err
} }
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
var deletedSetupKey *types.SetupKey var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil { if err != nil {
return err return err
} }
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
} }
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil return nil

View File

@@ -478,7 +478,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -488,11 +488,7 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
} }
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&types.Account{}).Select("id"). result := tx.Model(&types.Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
@@ -542,11 +538,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
} }
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) { func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User var user types.User
result := tx. result := tx.
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
@@ -563,11 +555,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User var user types.User
result := tx.First(&user, idQueryCondition, userID) result := tx.First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
@@ -600,11 +588,7 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var users []*types.User var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID) result := tx.Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -619,11 +603,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
} }
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) { func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User var user types.User
result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil { if result.Error != nil {
@@ -637,11 +617,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
} }
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID) result := tx.Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -656,11 +632,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
} }
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
likePattern := `%"ID":"` + resourceID + `"%` likePattern := `%"ID":"` + resourceID + `"%`
@@ -706,11 +678,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
} }
func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) { func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountMeta types.AccountMeta var accountMeta types.AccountMeta
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
First(&accountMeta, idQueryCondition, accountID) First(&accountMeta, idQueryCondition, accountID)
@@ -880,11 +848,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
} }
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&types.User{}). result := tx.Model(&types.User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID) Select("account_id").Where(idQueryCondition, userID).First(&accountID)
@@ -899,11 +863,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
} }
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) { func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&nbpeer.Peer{}). result := tx.Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID) Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
@@ -936,11 +896,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
} }
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var ipJSONStrings []string var ipJSONStrings []string
// Fetch the IP addresses as JSON strings // Fetch the IP addresses as JSON strings
@@ -968,11 +924,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
} }
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var labels []string var labels []string
result := tx.Model(&nbpeer.Peer{}). result := tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
@@ -990,11 +942,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1006,11 +954,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx.First(&peer, GetKeyQueryCondition(s), peerKey) result := tx.First(&peer, GetKeyQueryCondition(s), peerKey)
@@ -1025,11 +969,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountSettings types.AccountSettings var accountSettings types.AccountSettings
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1041,11 +981,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
} }
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var createdBy string var createdBy string
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID) Select("created_by").First(&createdBy, idQueryCondition, accountID)
@@ -1243,11 +1179,7 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey var setupKey types.SetupKey
result := tx. result := tx.
First(&setupKey, GetKeyQueryCondition(s), key) First(&setupKey, GetKeyQueryCondition(s), key)
@@ -1391,11 +1323,7 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account. // GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
query := tx. query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
@@ -1409,8 +1337,9 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
// GetAccountPeers retrieves peers for an account. // GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID) query := tx.Where(accountIDCondition, accountID)
if nameFilter != "" { if nameFilter != "" {
query = query.Where("name LIKE ?", "%"+nameFilter+"%") query = query.Where("name LIKE ?", "%"+nameFilter+"%")
@@ -1429,11 +1358,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
@@ -1461,11 +1386,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStr
// GetPeerByID retrieves a peer by its ID and account ID. // GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer *nbpeer.Peer var peer *nbpeer.Peer
result := tx. result := tx.
First(&peer, accountAndIDQueryCondition, accountID, peerID) First(&peer, accountAndIDQueryCondition, accountID, peerID)
@@ -1481,11 +1402,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
// GetPeersByIDs retrieves peers by their IDs and account ID. // GetPeersByIDs retrieves peers by their IDs and account ID.
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
if result.Error != nil { if result.Error != nil {
@@ -1503,11 +1420,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. // GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx. result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1522,11 +1435,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. // GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx. result := tx.
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1541,11 +1450,7 @@ func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStreng
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. // GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var allEphemeralPeers, batchPeers []*nbpeer.Peer var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := tx. result := tx.
Where("ephemeral = ?", true). Where("ephemeral = ?", true).
@@ -1623,11 +1528,7 @@ func (s *SqlStore) GetDB() *gorm.DB {
} }
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountDNSSettings types.AccountDNSSettings var accountDNSSettings types.AccountDNSSettings
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID) First(&accountDNSSettings, idQueryCondition, accountID)
@@ -1643,11 +1544,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID. // AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
Select("id").First(&accountID, idQueryCondition, id) Select("id").First(&accountID, idQueryCondition, id)
@@ -1663,11 +1560,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var account types.Account var account types.Account
result := tx.Model(&types.Account{}).Select("domain", "domain_category"). result := tx.Model(&types.Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account) Where(idQueryCondition, accountID).First(&account)
@@ -1683,11 +1576,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
// GetGroupByID retrieves a group by ID and account ID. // GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group *types.Group var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID) result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@@ -1703,11 +1592,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
// GetGroupByName retrieves a group by name and account ID. // GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group types.Group var group types.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently // TODO: This fix is accepted for now, but if we need to handle this more frequently
@@ -1736,11 +1621,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// GetGroupsByIDs retrieves groups by their IDs and account ID. // GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil { if result.Error != nil {
@@ -1796,11 +1677,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var policies []*types.Policy var policies []*types.Policy
result := tx. result := tx.
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
@@ -1814,11 +1691,7 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var policy *types.Policy var policy *types.Policy
result := tx.Preload(clause.Associations). result := tx.Preload(clause.Associations).
@@ -1880,11 +1753,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountIDCondition, accountID) result := tx.Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -1897,10 +1766,7 @@ func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength Loc
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var postureCheck *posture.Checks var postureCheck *posture.Checks
result := tx. result := tx.
@@ -1918,11 +1784,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. // GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil { if result.Error != nil {
@@ -1967,9 +1829,9 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var routes []*route.Route var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.Find(&routes, accountIDCondition, accountID)
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store") return nil, status.Errorf(status.Internal, "failed to get routes from store")
@@ -1980,9 +1842,9 @@ func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStr
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var route *route.Route var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.First(&route, accountAndIDQueryCondition, accountID, routeID)
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID) return nil, status.NewRouteNotFoundError(routeID)
@@ -2023,11 +1885,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKeys []*types.SetupKey var setupKeys []*types.SetupKey
result := tx. result := tx.
Find(&setupKeys, accountIDCondition, accountID) Find(&setupKeys, accountIDCondition, accountID)
@@ -2041,14 +1899,9 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
// GetSetupKeyByID retrieves a setup key by its ID and account ID. // GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey *types.SetupKey var setupKey *types.SetupKey
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID) return nil, status.NewSetupKeyNotFoundError(setupKeyID)
@@ -2088,11 +1941,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var nsGroups []*nbdns.NameServerGroup var nsGroups []*nbdns.NameServerGroup
result := tx.Find(&nsGroups, accountIDCondition, accountID) result := tx.Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@@ -2105,11 +1954,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var nsGroup *nbdns.NameServerGroup var nsGroup *nbdns.NameServerGroup
result := tx. result := tx.
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
@@ -2182,11 +2027,7 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var networks []*networkTypes.Network var networks []*networkTypes.Network
result := tx.Find(&networks, accountIDCondition, accountID) result := tx.Find(&networks, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -2198,11 +2039,7 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
} }
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var network *networkTypes.Network var network *networkTypes.Network
result := tx. result := tx.
First(&network, accountAndIDQueryCondition, accountID, networkID) First(&network, accountAndIDQueryCondition, accountID, networkID)
@@ -2244,11 +2081,7 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
} }
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netRouters []*routerTypes.NetworkRouter var netRouters []*routerTypes.NetworkRouter
result := tx. result := tx.
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
@@ -2261,11 +2094,7 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
} }
func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netRouters []*routerTypes.NetworkRouter var netRouters []*routerTypes.NetworkRouter
result := tx. result := tx.
Find(&netRouters, accountIDCondition, accountID) Find(&netRouters, accountIDCondition, accountID)
@@ -2278,11 +2107,7 @@ func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrengt
} }
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netRouter *routerTypes.NetworkRouter var netRouter *routerTypes.NetworkRouter
result := tx. result := tx.
First(&netRouter, accountAndIDQueryCondition, accountID, routerID) First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
@@ -2323,11 +2148,7 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources []*resourceTypes.NetworkResource var netResources []*resourceTypes.NetworkResource
result := tx. result := tx.
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
@@ -2340,11 +2161,7 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
} }
func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources []*resourceTypes.NetworkResource var netResources []*resourceTypes.NetworkResource
result := tx. result := tx.
Find(&netResources, accountIDCondition, accountID) Find(&netResources, accountIDCondition, accountID)
@@ -2357,11 +2174,7 @@ func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStren
} }
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources *resourceTypes.NetworkResource var netResources *resourceTypes.NetworkResource
result := tx. result := tx.
First(&netResources, accountAndIDQueryCondition, accountID, resourceID) First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
@@ -2377,11 +2190,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources *resourceTypes.NetworkResource var netResources *resourceTypes.NetworkResource
result := tx. result := tx.
First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) First(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
@@ -2423,10 +2232,7 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token. // GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) { func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var pat types.PersonalAccessToken var pat types.PersonalAccessToken
result := tx.First(&pat, "hashed_token = ?", hashedToken) result := tx.First(&pat, "hashed_token = ?", hashedToken)
@@ -2443,11 +2249,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
// GetPATByID retrieves a personal access token by its ID and user ID. // GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) { func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var pat types.PersonalAccessToken var pat types.PersonalAccessToken
result := tx. result := tx.
First(&pat, "id = ? AND user_id = ?", patID, userID) First(&pat, "id = ? AND user_id = ?", patID, userID)
@@ -2464,11 +2266,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
// GetUserPATs retrieves personal access tokens for a user. // GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) { func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var pats []*types.PersonalAccessToken var pats []*types.PersonalAccessToken
result := tx.Find(&pats, "user_id = ?", userID) result := tx.Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@@ -2528,11 +2326,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
} }
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
jsonValue := fmt.Sprintf(`"%s"`, ip.String()) jsonValue := fmt.Sprintf(`"%s"`, ip.String())
var peer nbpeer.Peer var peer nbpeer.Peer
@@ -2559,3 +2353,11 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
return count, nil return count, nil
} }
func (s *SqlStore) getTXWithLockStrength(lockStrength LockingStrength) *gorm.DB {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
return tx
}

View File

@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviterID := userID inviterID := userID
if initiatorUser.IsServiceUser { if initiatorUser.IsServiceUser {
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
} }
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
} }
// GetUser looks up a user by provided nbContext.UserAuths. // GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already. // Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
// ListUsers returns lists of all users under the account. // ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name. // It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
} }
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -404,7 +404,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewAdminPermissionError() return status.NewAdminPermissionError()
} }
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
if err != nil { if err != nil {
return err return err
} }
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
} }
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var addUserEvents []func() var addUserEvents []func()
var usersToSave = make([]*types.User, 0, len(updates)) var usersToSave = make([]*types.User, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var initiatorUser *types.User var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -695,7 +695,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists { if !addIfNotExists {
@@ -830,7 +830,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var user *types.User var user *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err) return nil, fmt.Errorf("failed to get user: %w", err)
} }
@@ -840,7 +840,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{} accountUsers := []*types.User{}
switch { switch {
case allowed: case allowed:
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -933,7 +933,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account // expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1003,7 +1003,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -1017,7 +1017,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue continue
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
allErrors = errors.Join(allErrors, err) allErrors = errors.Join(allErrors, err)
continue continue
@@ -1081,12 +1081,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var err error var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID) targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user to delete: %w", err) return fmt.Errorf("failed to get user to delete: %w", err)
} }
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user peers: %w", err) return fmt.Errorf("failed to get user peers: %w", err)
} }
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
// GetOwnerInfo retrieves the owner information for a given account ID. // GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) { func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID) owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1257,7 +1257,7 @@ func validateUserInvite(invite *types.UserInfo) error {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1274,7 +1274,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true) _, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail") assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId) user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id) assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)

View File

@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
} }
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
} }
func NewManagerMock() Manager { func NewManagerMock() Manager {