diff --git a/management/server/account.go b/management/server/account.go index 49ea38fe1..0f60bc91c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -40,12 +40,12 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -346,12 +346,12 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || groupsUpdated { - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } } - return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings) + return transaction.SaveAccountSettings(ctx, accountID, newSettings) }) if err != nil { return nil, err @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { return err } @@ -746,7 +746,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) + return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID) } // GetAccountIDByUserID retrieves the account ID based on the userID provided. @@ -758,7 +758,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -813,7 +813,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) accountIDString := fmt.Sprintf("%v", accountID) - accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString) if err != nil { return nil, nil, err } @@ -867,7 +867,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { - accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -897,7 +897,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) return nil, err @@ -1048,7 +1048,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err @@ -1058,7 +1058,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, newUser := types.NewRegularUser(userAuth.UserId) newUser.AccountID = domainAccountID - err := am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser) + err := am.Store.SaveUser(ctx, newUser) if err != nil { return "", err } @@ -1223,7 +1223,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID) } // GetAccountOnboarding retrieves the onboarding information for a specific account. @@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) @@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId) if err != nil { return err } @@ -1366,12 +1366,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth var hasChanges bool var user *types.User err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -1387,7 +1387,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth return nil } - if err = transaction.CreateGroups(ctx, store.LockingStrengthUpdate, userAuth.AccountId, newGroupsToCreate); err != nil { + if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } @@ -1395,13 +1395,13 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) user.AutoGroups = updatedAutoGroups - if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { + if err = transaction.SaveUser(ctx, user); err != nil { return fmt.Errorf("error saving user: %w", err) } // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -1419,7 +1419,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, userAuth.AccountId); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, userAuth.AccountId); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -1437,7 +1437,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) } else { @@ -1450,7 +1450,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) } else { @@ -1511,7 +1511,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context } if userAuth.IsChild { - exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) + exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId) if err != nil || !exists { return "", err } @@ -1535,7 +1535,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -1556,7 +1556,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -1571,7 +1571,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont cancel := am.Store.AcquireGlobalLock(ctx) // check again if the domain has a primary account because of simultaneous requests - domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain) if handleNotFound(err) != nil { cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -1582,7 +1582,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -1592,7 +1592,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err @@ -1603,7 +1603,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context } // We checked if the domain has a primary account already - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err @@ -1751,7 +1751,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee } func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { - user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID) if err != nil { return false, err } @@ -1780,7 +1780,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account if !allowed { return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id @@ -1870,7 +1870,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C cancel := am.Store.AcquireGlobalLock(ctx) defer cancel() - existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) + existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain) if handleNotFound(err) != nil { return nil, false, err } @@ -1890,7 +1890,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C for range 2 { accountId := xid.New().String() - exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) + exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId) if err != nil || exists { continue } @@ -1965,7 +1965,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc return nil } - existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain) + existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain) // error is not a not found error if handleNotFound(err) != nil { @@ -2002,17 +2002,17 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc // propagateUserGroupMemberships propagates all account users' group memberships to their peers. // Returns true if any groups were modified, true if those updates affect peers and an error. func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { - users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, false, err } - accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthShare, accountID) + accountGroupPeers, err := transaction.GetAccountGroupPeers(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, false, fmt.Errorf("error getting account group peers: %w", err) } - accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + accountGroups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, false, fmt.Errorf("error getting account groups: %w", err) } @@ -2025,7 +2025,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, updatedGroups := []string{} for _, user := range users { - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id) if err != nil { return false, false, err } @@ -2074,7 +2074,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t account.Network.Net = newIPNet - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { return err } @@ -2099,7 +2099,7 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t } for _, peer := range peers { - if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { + if err = transaction.SavePeer(ctx, accountID, peer); err != nil { return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) } } @@ -2154,7 +2154,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, return fmt.Errorf("get account: %w", err) } - existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) + existingPeer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { return fmt.Errorf("get peer: %w", err) } @@ -2185,7 +2185,7 @@ func (am *DefaultAccountManager) updatePeerIPInTransaction(ctx context.Context, func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transaction store.Store, accountID, userID string, peer *nbpeer.Peer, newIP netip.Addr) error { log.WithContext(ctx).Infof("updating peer %s IP from %s to %s", peer.ID, peer.IP, newIP) - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return fmt.Errorf("get account settings: %w", err) } @@ -2195,7 +2195,7 @@ func (am *DefaultAccountManager) savePeerIPUpdate(ctx context.Context, transacti oldIP := peer.IP.String() peer.IP = newIP.AsSlice() - err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) + err = transaction.SavePeer(ctx, accountID, peer) if err != nil { return fmt.Errorf("save peer: %w", err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 77014855f..0c618a8a3 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -783,7 +783,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { return } - exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) + exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID) assert.NoError(t, err) assert.True(t, exists, "expected to get existing account after creation using userid") @@ -900,11 +900,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) { t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) } - pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") + pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1") require.NoError(t, err) assert.Len(t, pats, 0) - pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) + pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId) require.NoError(t, err) assert.Len(t, pats, 0) } @@ -1786,7 +1786,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID) require.NoError(t, err, "unable to get account settings") assert.NotNil(t, settings) @@ -1971,7 +1971,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) - settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID) require.NoError(t, err, "unable to get account settings") assert.False(t, settings.PeerLoginExpirationEnabled) @@ -2655,7 +2655,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") }) @@ -2669,7 +2669,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) @@ -2683,18 +2683,18 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} - assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) + assert.NoError(t, manager.Store.SaveUser(context.Background(), account.Users["user1"])) claims := nbcontext.UserAuth{ UserId: "user1", @@ -2704,11 +2704,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) @@ -2722,7 +2722,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2736,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2750,11 +2750,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "new group should be added") }) @@ -2768,7 +2768,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") @@ -2783,7 +2783,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.SyncUserJWTGroups(context.Background(), claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") }) @@ -3348,11 +3348,11 @@ func TestPropagateUserGroupMemberships(t *testing.T) { require.NoError(t, err) peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId, IP: net.IP{1, 1, 1, 1}, DNSLabel: "peer1.domain.test"} - err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) + err = manager.Store.AddPeerToAccount(ctx, peer1) require.NoError(t, err) peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId, IP: net.IP{2, 2, 2, 2}, DNSLabel: "peer2.domain.test"} - err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) + err = manager.Store.AddPeerToAccount(ctx, peer2) require.NoError(t, err) t.Run("should skip propagation when the user has no groups", func(t *testing.T) { @@ -3364,20 +3364,20 @@ func TestPropagateUserGroupMemberships(t *testing.T) { t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} - require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group1)) + require.NoError(t, manager.Store.CreateGroup(ctx, group1)) - user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) + user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId) require.NoError(t, err) user.AutoGroups = append(user.AutoGroups, group1.ID) - require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) + require.NoError(t, manager.Store.SaveUser(ctx, user)) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.True(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) - group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) + group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID) require.NoError(t, err) assert.Len(t, group.Peers, 2) assert.Contains(t, group.Peers, "peer1") @@ -3386,13 +3386,13 @@ func TestPropagateUserGroupMemberships(t *testing.T) { t.Run("should update membership and account peers for used groups", func(t *testing.T) { group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} - require.NoError(t, manager.Store.CreateGroup(ctx, store.LockingStrengthUpdate, group2)) + require.NoError(t, manager.Store.CreateGroup(ctx, group2)) - user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) + user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId) require.NoError(t, err) user.AutoGroups = append(user.AutoGroups, group2.ID) - require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) + require.NoError(t, manager.Store.SaveUser(ctx, user)) _, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{ Name: "Group1 Policy", @@ -3415,7 +3415,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) { assert.True(t, groupsUpdated) assert.True(t, groupChangesAffectPeers) - groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) + groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"}) require.NoError(t, err) for _, group := range groups { assert.Len(t, group.Peers, 2) @@ -3432,18 +3432,18 @@ func TestPropagateUserGroupMemberships(t *testing.T) { }) t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { - user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) + user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId) require.NoError(t, err) user.AutoGroups = []string{"group1"} - require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) + require.NoError(t, manager.Store.SaveUser(ctx, user)) groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) require.NoError(t, err) assert.False(t, groupsUpdated) assert.False(t, groupChangesAffectPeers) - groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) + groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"}) require.NoError(t, err) for _, group := range groups { assert.Len(t, group.Peers, 2) diff --git a/management/server/auth/manager.go b/management/server/auth/manager.go index 6835a3ced..53d479c90 100644 --- a/management/server/auth/manager.go +++ b/management/server/auth/manager.go @@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco return userAuth, nil } - settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId) if err != nil { return userAuth, err } @@ -94,7 +94,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco // MarkPATUsed marks a personal access token as used func (am *manager) MarkPATUsed(ctx context.Context, tokenID string) error { - return am.store.MarkPATUsed(ctx, store.LockingStrengthUpdate, tokenID) + return am.store.MarkPATUsed(ctx, tokenID) } // GetPATInfo retrieves user, personal access token, domain, and category details from a personal access token. @@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us return nil, nil, "", "", err } - domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) + domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID) if err != nil { return nil, nil, "", "", err } @@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type var pat *types.PersonalAccessToken err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) + pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken) if err != nil { return err } - user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) + user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID) return err }) if err != nil { diff --git a/management/server/dns.go b/management/server/dns.go index 489618c17..12aa6e21c 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -8,14 +8,14 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" ) // DNSConfigCache is a thread-safe cache for DNS configuration components @@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings @@ -113,11 +113,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) + return transaction.SaveDNSSettings(ctx, accountID, dnsSettingsToSave) }) if err != nil { return err @@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) return nil @@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID return nil } - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups) if err != nil { return err } diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 9f4348ebb..e3cb5459a 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -134,7 +134,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) + peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone) if err != nil { log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) return diff --git a/management/server/event.go b/management/server/event.go index 3144c52ea..d26c569ae 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -11,9 +11,9 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" ) func isEnabled() bool { @@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve } func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) { - accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId) if err != nil { return nil, err } @@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context, continue } - externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) + externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) if err != nil { // @todo consider logging continue diff --git a/management/server/group.go b/management/server/group.go index 2b804b5f6..915a87086 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -14,11 +14,11 @@ import ( routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" ) type GroupLinkError struct { @@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { - return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) + return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName) } // CreateGroup object of the peers @@ -96,11 +96,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - if err := transaction.CreateGroup(ctx, store.LockingStrengthUpdate, newGroup); err != nil { + if err := transaction.CreateGroup(ctx, newGroup); err != nil { return status.Errorf(status.Internal, "failed to create group: %v", err) } @@ -147,7 +147,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } - oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) if err != nil { return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) } @@ -176,11 +176,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, newGroup) + return transaction.UpdateGroup(ctx, newGroup) }) if err != nil { return err @@ -234,11 +234,11 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.CreateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) + return transaction.CreateGroups(ctx, accountID, groupsToSave) }) if err != nil { return err @@ -292,11 +292,11 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.UpdateGroups(ctx, store.LockingStrengthUpdate, accountID, groupsToSave) + return transaction.UpdateGroups(ctx, accountID, groupsToSave) }) if err != nil { return err @@ -320,7 +320,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) if err == nil && oldGroup != nil { addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) @@ -332,13 +332,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac } modifiedPeers := slices.Concat(addedPeers, removedPeers) - peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) + peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers) if err != nil { log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) return nil } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) return nil @@ -423,11 +423,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) + return transaction.DeleteGroups(ctx, accountID, groupIDsToDelete) }) if err != nil { return err @@ -454,7 +454,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } @@ -495,11 +495,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.UpdateGroup(ctx, group) }) if err != nil { return err @@ -526,7 +526,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } @@ -567,11 +567,11 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.UpdateGroup(ctx, store.LockingStrengthUpdate, group) + return transaction.UpdateGroup(ctx, group) }) if err != nil { return err @@ -591,7 +591,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st } if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { - existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) + existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name) if err != nil { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { return err @@ -608,7 +608,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st } for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) + _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } @@ -620,7 +620,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == types.GroupIssuedIntegration { - executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return status.Errorf(status.Internal, "failed to get user") } @@ -666,7 +666,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID) if err != nil { return status.Errorf(status.Internal, "failed to get DNS settings") } @@ -675,7 +675,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID) if err != nil { return status.Errorf(status.Internal, "failed to get account settings") } @@ -689,7 +689,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr // isGroupLinkedToRoute checks if a group is linked to any route in the account. func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { - routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -709,7 +709,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -727,7 +727,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -746,7 +746,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { - setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) + setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -762,7 +762,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou // isGroupLinkedToUser checks if a group is linked to any user in the account. func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { - users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -778,7 +778,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID // isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { - routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) return false, nil @@ -798,7 +798,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac return false, nil } - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, err } @@ -826,7 +826,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs) if err != nil { return false, err } diff --git a/management/server/group_test.go b/management/server/group_test.go index c31280156..1626a0464 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -26,10 +26,10 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" peer2 "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -898,7 +898,7 @@ func Test_AddPeerAndAddToAll(t *testing.T) { } err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { - err = transaction.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + err = transaction.AddPeerToAccount(context.Background(), peer) if err != nil { return fmt.Errorf("AddPeer failed for peer %d: %w", i, err) } @@ -971,7 +971,7 @@ func Test_IncrementNetworkSerial(t *testing.T) { <-start err = manager.Store.ExecuteInTransaction(context.Background(), func(transaction store.Store) error { - err = transaction.IncrementNetworkSerial(context.Background(), store.LockingStrengthNone, accountID) + err = transaction.IncrementNetworkSerial(context.Background(), accountID) if err != nil { return fmt.Errorf("failed to get account %s: %v", accountID, err) } diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index 8647eeb68..dd11f862f 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -6,12 +6,12 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/http/api" ) type Manager interface { @@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string return nil, err } - groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("error getting account groups: %w", err) } @@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans return nil, fmt.Errorf("error adding resource to group: %w", err) } - group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) if err != nil { return nil, fmt.Errorf("error getting group: %w", err) } // TODO: at some point, this will need to become a switch statement - networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID) if err != nil { return nil, fmt.Errorf("error getting network resource: %w", err) } @@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context, return nil, fmt.Errorf("error removing resource from group: %w", err) } - group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) if err != nil { return nil, fmt.Errorf("error getting group: %w", err) } // TODO: at some point, this will need to become a switch statement - networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { return nil, fmt.Errorf("error getting network resource: %w", err) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index f5d9c250b..8893ad2e2 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -24,7 +24,6 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/auth" @@ -32,9 +31,10 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/settings" - internalStatus "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" + internalStatus "github.com/netbirdio/netbird/shared/management/status" ) // GRPCServer an instance of a Management gRPC API server @@ -920,7 +920,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (* return nil, err } - peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peerKey.String()) + peer, err := s.accountManager.GetStore().GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerKey.String()) if err != nil { log.WithContext(ctx).Debugf("peer %s is not registered for logout", peerKey.String()) // TODO: consider idempotency diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 1e92e0c50..73abacc36 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -77,7 +77,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) + _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID) if err != nil { return err } @@ -97,17 +97,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI var peers []*nbpeer.Peer var settings *types.Settings - groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } - peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") + peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { return nil, err } - settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index b896c52da..c9f8b5448 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -22,7 +22,6 @@ import ( "github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/formatter/hook" - mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" @@ -30,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + mgmtProto "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/util" ) @@ -645,7 +645,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return diff --git a/management/server/nameserver.go b/management/server/nameserver.go index ef62d0eaf..1ee8805fc 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -13,9 +13,9 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" ) const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*[*.a-z]{1,}$` @@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, status.NewPermissionDeniedError() } - return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -73,11 +73,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) + return transaction.SaveNameServerGroup(ctx, newNSGroup) }) if err != nil { return nil, err @@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID) if err != nil { return err } @@ -127,11 +127,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) + return transaction.SaveNameServerGroup(ctx, nsGroupToSave) }) if err != nil { return err @@ -173,11 +173,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) + return transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID) }) if err != nil { return err @@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) } func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { @@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou return err } - nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } @@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou return err } - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups) if err != nil { return err } diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index cb1116b0e..2bab0e289 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -14,8 +14,8 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri return nil, status.NewPermissionDeniedError() } - return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) + return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID) } func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { @@ -73,7 +73,7 @@ func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID) defer unlock() - err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) + err = m.store.SaveNetwork(ctx, network) if err != nil { return nil, fmt.Errorf("failed to save network: %w", err) } @@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network return nil, status.NewPermissionDeniedError() } - return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) + return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID) } func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { @@ -114,7 +114,7 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta()) - return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network) + return network, m.store.SaveNetwork(ctx, network) } func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { @@ -162,12 +162,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw eventsToStore = append(eventsToStore, event) } - err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + err = transaction.DeleteNetwork(ctx, accountID, networkID) if err != nil { return fmt.Errorf("failed to delete network: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 7c8c68e32..d0b29075b 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -12,10 +12,10 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" nbtypes "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u return nil, status.NewPermissionDeniedError() } - return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) + return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID) } func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { @@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u return nil, status.NewPermissionDeniedError() } - return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) + return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID) } func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { @@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, return nil, status.NewPermissionDeniedError() } - resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) + resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get network resources: %w", err) } @@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc var eventsToStore []func() err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name) if err == nil { return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) } @@ -123,7 +123,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return fmt.Errorf("failed to get network: %w", err) } - err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + err = transaction.SaveNetworkResource(ctx, resource) if err != nil { return fmt.Errorf("failed to save network resource: %w", err) } @@ -145,7 +145,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc eventsToStore = append(eventsToStore, event) } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + err = transaction.IncrementNetworkSerial(ctx, resource.AccountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ return nil, status.NewPermissionDeniedError() } - resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { return nil, fmt.Errorf("failed to get network resource: %w", err) } @@ -218,22 +218,22 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) } - _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID) if err != nil { return fmt.Errorf("failed to get network resource: %w", err) } - oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name) if err == nil && oldResource.ID != resource.ID { return status.Errorf(status.InvalidArgument, "new resource name already exists") } - oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID) if err != nil { return fmt.Errorf("failed to get network resource: %w", err) } - err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + err = transaction.SaveNetworkResource(ctx, resource) if err != nil { return fmt.Errorf("failed to save network resource: %w", err) } @@ -248,7 +248,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network)) }) - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + err = transaction.IncrementNetworkSerial(ctx, resource.AccountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -325,7 +325,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return fmt.Errorf("failed to delete resource: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -375,7 +375,7 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti eventsToStore = append(eventsToStore, event) } - err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + err = transaction.DeleteNetworkResource(ctx, accountID, resourceID) if err != nil { return nil, fmt.Errorf("failed to delete network resource: %w", err) } diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index ff5a771d6..ca99e4fd1 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -14,8 +14,8 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use return nil, status.NewPermissionDeniedError() } - return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) + return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID) } func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { @@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use return nil, status.NewPermissionDeniedError() } - routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get network routers: %w", err) } @@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t var network *networkTypes.Network err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) if err != nil { return fmt.Errorf("failed to get network: %w", err) } @@ -104,12 +104,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t router.ID = xid.New().String() - err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + err = transaction.SaveNetworkRouter(ctx, router) if err != nil { return fmt.Errorf("failed to create network router: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + err = transaction.IncrementNetworkSerial(ctx, router.AccountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI return nil, status.NewPermissionDeniedError() } - router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) + router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID) if err != nil { return nil, fmt.Errorf("failed to get network router: %w", err) } @@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t var network *networkTypes.Network err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) if err != nil { return fmt.Errorf("failed to get network: %w", err) } @@ -171,12 +171,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) } - err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + err = transaction.SaveNetworkRouter(ctx, router) if err != nil { return fmt.Errorf("failed to update network router: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + err = transaction.IncrementNetworkSerial(ctx, router.AccountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -213,7 +213,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo return fmt.Errorf("failed to delete network router: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo } func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { - network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID) if err != nil { return nil, fmt.Errorf("failed to get network: %w", err) } @@ -246,7 +246,7 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) } - err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + err = transaction.DeleteNetworkRouter(ctx, accountID, routerID) if err != nil { return nil, fmt.Errorf("failed to delete network router: %w", err) } diff --git a/management/server/peer.go b/management/server/peer.go index 6cd519f27..d72eac91a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -17,28 +17,28 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, status.NewPermissionValidationError(err) } - accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter) + accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter) if err != nil { return nil, err } @@ -58,7 +58,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return accountPeers, nil } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get account settings: %w", err) } @@ -130,7 +130,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if peer.AddedWithSSOLogin() { - settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } @@ -173,7 +173,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio peer.Location.CountryCode = location.Country.ISOCode peer.Location.CityName = location.City.Names.En peer.Location.GeoNameID = location.City.GeonameID - err = transaction.SavePeerLocation(ctx, store.LockingStrengthUpdate, accountID, peer) + err = transaction.SavePeerLocation(ctx, accountID, peer) if err != nil { log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) } @@ -182,7 +182,7 @@ func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocatio log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) - err := transaction.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *newStatus) + err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) if err != nil { return false, err } @@ -219,7 +219,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return err } - settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } @@ -281,7 +281,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user inactivityExpirationChanged = true } - return transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer) + return transaction.SavePeer(ctx, accountID, peer) }) if err != nil { return nil, err @@ -346,7 +346,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return status.NewPermissionDeniedError() } - peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) if err != nil { return err } @@ -383,7 +383,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return fmt.Errorf("failed to delete peer: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s }() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) + err = transaction.AddPeerToAccount(ctx, newPeer) if err != nil { return err } @@ -658,7 +658,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -734,7 +734,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy var err error var postureChecks []*posture.Checks - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, nil, nil, err } @@ -746,7 +746,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if peer.UserID != "" { - user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID) if err != nil { return err } @@ -774,7 +774,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) - if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { + if err = transaction.SavePeer(ctx, accountID, peer); err != nil { return err } @@ -849,7 +849,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer var isPeerUpdated bool var postureChecks []*posture.Checks - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, nil, nil, err } @@ -911,7 +911,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } if shouldStorePeer { - if err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, accountID, peer); err != nil { + if err = transaction.SavePeer(ctx, accountID, peer); err != nil { return err } } @@ -934,7 +934,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer // getPeerPostureChecks returns the posture checks for the peer. func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -958,7 +958,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) } - peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs) if err != nil { return nil, err } @@ -973,7 +973,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli continue } - sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) + sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources) if err != nil { return nil, err } @@ -998,7 +998,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey) if err != nil { return err } @@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } @@ -1080,7 +1080,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact // If peer was expired before and if it reached this point, it is re-authenticated. // UserID is present, meaning that JWT validation passed successfully in the API layer. peer = peer.UpdateLastLogin() - err = transaction.SavePeer(ctx, store.LockingStrengthUpdate, peer.AccountID, peer) + err = transaction.SavePeer(ctx, peer.AccountID, peer) if err != nil { return err } @@ -1090,7 +1090,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact log.WithContext(ctx).Debugf("failed to update user last login: %v", err) } - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID) if err != nil { return fmt.Errorf("failed to get account settings: %w", err) } @@ -1132,7 +1132,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { - peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) if err != nil { return nil, err } @@ -1145,7 +1145,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return peer, nil } - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err } @@ -1171,7 +1171,7 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // it is also possible that user doesn't own the peer but some of his peers have access to it, // this is a valid case, show the peer as well. - userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) + userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) if err != nil { return nil, err } @@ -1394,7 +1394,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are connected. func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) return peerSchedulerRetryInterval, true @@ -1404,7 +1404,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco return 0, false } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) return peerSchedulerRetryInterval, true @@ -1438,7 +1438,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are not connected. func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) return peerSchedulerRetryInterval, true @@ -1448,7 +1448,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte return 0, false } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) return peerSchedulerRetryInterval, true @@ -1479,12 +1479,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte // getExpiredPeers returns peers that have been expired. func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -1502,12 +1502,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID // getInactivePeers returns peers that have been expired by inactivity func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -1530,7 +1530,7 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { - return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthShare, accountID, peerID) + return transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID) } // IsPeerInActiveGroup checks if the given peer is part of a group that is used @@ -1548,7 +1548,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -1568,7 +1568,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto return nil, err } - if err = transaction.DeletePeer(ctx, store.LockingStrengthUpdate, accountID, peer.ID); err != nil { + if err = transaction.DeletePeer(ctx, accountID, peer.ID); err != nil { return nil, err } @@ -1624,7 +1624,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac // isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account. func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) { - routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err) return false, nil diff --git a/management/server/peer_test.go b/management/server/peer_test.go index a78f220d3..5c52692f3 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -38,8 +38,6 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" @@ -47,6 +45,8 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" nbroute "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/shared/management/proto" ) func TestPeer_LoginExpired(t *testing.T) { @@ -1307,7 +1307,7 @@ func Test_RegisterPeerByUser(t *testing.T) { require.NoError(t, err) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) - peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.UserID, existingUserID) @@ -1442,7 +1442,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { assert.NotNil(t, addedPeer, "addedPeer should not be nil on success") assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch") - peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key) + peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key) require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key) assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store") @@ -1528,7 +1528,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) require.Error(t, err) - _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) + _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key) require.Error(t, err) account, err := s.GetAccount(context.Background(), existingAccountID) @@ -1699,7 +1699,7 @@ func Test_LoginPeer(t *testing.T) { assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer") - peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey) + peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey) require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey) assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore") @@ -2160,10 +2160,10 @@ func Test_IsUniqueConstraintError(t *testing.T) { } t.Cleanup(cleanup) - err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + err = s.AddPeerToAccount(context.Background(), peer) assert.NoError(t, err) - err = s.AddPeerToAccount(context.Background(), store.LockingStrengthUpdate, peer) + err = s.AddPeerToAccount(context.Background(), peer) result := isUniqueConstraintError(err) assert.True(t, result) }) diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go index 7e0262ef9..50e36a880 100644 --- a/management/server/peers/manager.go +++ b/management/server/peers/manager.go @@ -10,8 +10,8 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str return nil, status.NewPermissionDeniedError() } - return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) + return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) } func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { @@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) } if !allowed { - return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) + return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) } - return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") + return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") } func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { - return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) + return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) } diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 4b2d91c3b..0ab244243 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -11,9 +11,9 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/roles" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions( return true, nil } - user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return false, err } diff --git a/management/server/policy.go b/management/server/policy.go index b19f99960..d5c66e9f8 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -6,11 +6,11 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" @@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, status.NewPermissionDeniedError() } - return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) + return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID) } // SavePolicy in the store @@ -61,7 +61,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } @@ -71,7 +71,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user saveFunc = transaction.SavePolicy } - return saveFunc(ctx, store.LockingStrengthUpdate, policy) + return saveFunc(ctx, policy) }) if err != nil { return nil, err @@ -113,11 +113,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, accountID, policyID) + return transaction.DeletePolicy(ctx, accountID, policyID) }) if err != nil { return err @@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return false, err } @@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a // validatePolicy validates the policy and its rules. func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) + _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { return err } @@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.AccountID = accountID } - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { return err } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 50f7e4776..9414b8065 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -13,9 +13,9 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, status.NewPermissionDeniedError() } - return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) + return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID) } // SavePostureChecks saves a posture check. @@ -62,7 +62,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } postureChecks.AccountID = accountID - return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) + return transaction.SavePostureChecks(ctx, postureChecks) }) if err != nil { return nil, err @@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun var postureChecks *posture.Checks err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) + postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID) if err != nil { return err } @@ -110,11 +110,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) + return transaction.DeletePostureChecks(ctx, accountID, postureChecksID) }) if err != nil { return err @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) } // getPeerPostureChecks returns the posture checks applied for a given peer. @@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return false, err } @@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account // If the posture check already has an ID, verify its existence in the store. if postureChecks.ID != "" { - if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { + if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil { return err } return nil } // For new posture checks, ensure no duplicates by name. - checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) + checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } @@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } diff --git a/management/server/route.go b/management/server/route.go index 50c4bdc4a..6adff56b5 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -9,15 +9,15 @@ import ( "github.com/rs/xid" - "github.com/netbirdio/netbird/shared/management/domain" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/shared/management/proto" + "github.com/netbirdio/netbird/shared/management/status" ) // GetRoute gets a route object from account and route IDs @@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.NewPermissionDeniedError() } - return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) + return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID)) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. @@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto seenPeers[string(prefixRoute.ID)] = true } - peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) + peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups) if err != nil { return err } @@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto if peerID := checkRoute.Peer; peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) + _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID) if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } @@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto } // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix - peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) + peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers) if err != nil { return err } @@ -181,11 +181,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute) + return transaction.SaveRoute(ctx, newRoute) }) if err != nil { return nil, err @@ -238,11 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } routeToSave.AccountID = accountID - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave) + return transaction.SaveRoute(ctx, routeToSave) }) if err != nil { return err @@ -284,11 +284,11 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri return err } - if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { return err } - return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID)) + return transaction.DeleteRoute(ctx, accountID, string(routeID)) }) am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) @@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) } func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { @@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin // validateRouteGroups validates the route groups and returns the validated groups map. func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) - groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) + groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate) if err != nil { return nil, err } @@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro // GetRoutesByPrefixOrDomains return list of routes by account and route prefix func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { - accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) + accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } diff --git a/management/server/route_test.go b/management/server/route_test.go index ac51c25fd..c3eea35ea 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -27,6 +26,7 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) const ( @@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *types.Group for _, group := range groups { diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 9cb142938..6d09f1786 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -11,10 +11,10 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { @@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) return nil, fmt.Errorf("get extra settings: %w", err) } - settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("get account settings: %w", err) } @@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (* return nil, fmt.Errorf("get extra settings: %w", err) } - settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("get account settings: %w", err) } diff --git a/management/server/setupkey.go b/management/server/setupkey.go index e3647f6df..71915b4a2 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -10,10 +10,10 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -81,7 +81,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) + return transaction.SaveSetupKey(ctx, setupKey) }) if err != nil { return nil, err @@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } - oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) + oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyToSave.Id) if err != nil { return err } @@ -148,7 +148,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) + return transaction.SaveSetupKey(ctx, newKey) }) if err != nil { return nil, err @@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. @@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewPermissionDeniedError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID) if err != nil { return nil, err } @@ -214,12 +214,12 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, var deletedSetupKey *types.SetupKey err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthUpdate, accountID, keyID) if err != nil { return err } - return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) + return transaction.DeleteSetupKey(ctx, accountID, keyID) }) if err != nil { return err @@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, } func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs) if err != nil { return err } @@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index c0f683124..5986214d9 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -29,11 +29,11 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -270,7 +270,7 @@ func generateAccountSQLTypes(account *types.Account) { func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { var acc types.Account var domain string - result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) + result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).Take(&domain) if result.Error != nil { if !errors.Is(result.Error, gorm.ErrRecordNotFound) { log.WithContext(ctx).Errorf("error when getting account %s from the store to check domain: %s", accountID, result.Error) @@ -323,23 +323,26 @@ func (s *SqlStore) SaveInstallationID(_ context.Context, ID string) error { func (s *SqlStore) GetInstallationID() string { var installation installation - if result := s.db.First(&installation, idQueryCondition, s.installationPK); result.Error != nil { + if result := s.db.Take(&installation, idQueryCondition, s.installationPK); result.Error != nil { return "" } return installation.InstallationIDValue } -func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error { +func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Transaction(func(tx *gorm.DB) error { + err := s.db.Transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string - result := tx.Model(&nbpeer.Peer{}).Select("id").Find(&peerID, accountAndIDQueryCondition, accountID, peer.ID) + result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID) if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, peerNotFoundFMT, peer.ID) + } return result.Error } @@ -385,7 +388,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID return nil } -func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { +func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, peerStatus nbpeer.PeerStatus) error { var peerCopy nbpeer.Peer peerCopy.Status = &peerStatus @@ -393,7 +396,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren "peer_status_last_seen", "peer_status_connected", "peer_status_login_expired", "peer_status_required_approval", } - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := s.db.Model(&nbpeer.Peer{}). Select(fieldsToUpdate). Where(accountAndIDQueryCondition, accountID, peerID). Updates(&peerCopy) @@ -408,14 +411,14 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren return nil } -func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peerWithLocation *nbpeer.Peer) error { +func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer // Since the location field has been migrated to JSON serialization, // updating the struct ensures the correct data format is inserted into the database. peerCopy.Location = peerWithLocation.Location - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + result := s.db.Model(&nbpeer.Peer{}). Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID). Updates(peerCopy) @@ -431,12 +434,12 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr } // SaveUsers saves the given list of users to the database. -func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error { +func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { if len(users) == 0 { return nil } - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{UpdateAll: true}).Create(&users) + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(&users) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save users to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save users to store") @@ -445,8 +448,8 @@ func (s *SqlStore) SaveUsers(ctx context.Context, lockStrength LockingStrength, } // SaveUser saves the given user to the database. -func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) +func (s *SqlStore) SaveUser(ctx context.Context, user *types.User) error { + result := s.db.Save(user) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save user to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save user to store") @@ -455,7 +458,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // CreateGroups creates the given list of groups to the database. -func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { +func (s *SqlStore) CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -463,7 +466,6 @@ func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrengt return s.db.Transaction(func(tx *gorm.DB) error { result := tx. Clauses( - clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{ Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, UpdateAll: true, @@ -481,7 +483,7 @@ func (s *SqlStore) CreateGroups(ctx context.Context, lockStrength LockingStrengt } // UpdateGroups updates the given list of groups to the database. -func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error { +func (s *SqlStore) UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -489,7 +491,6 @@ func (s *SqlStore) UpdateGroups(ctx context.Context, lockStrength LockingStrengt return s.db.Transaction(func(tx *gorm.DB) error { result := tx. Clauses( - clause.Locking{Strength: string(lockStrength)}, clause.OnConflict{ Where: clause.Where{Exprs: []clause.Expression{clause.Eq{Column: "groups.account_id", Value: accountID}}}, UpdateAll: true, @@ -517,7 +518,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { } func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { - accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain) if err != nil { return nil, err } @@ -536,7 +537,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength result := tx.Model(&types.Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", strings.ToLower(domain), true, types.PrivateCategory, - ).First(&accountID) + ).Take(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") @@ -550,7 +551,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { var key types.SetupKey - result := s.db.Select("account_id").First(&key, GetKeyQueryCondition(s), setupKey) + result := s.db.Select("account_id").Take(&key, GetKeyQueryCondition(s), setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewSetupKeyNotFoundError(setupKey) @@ -568,7 +569,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { var token types.PersonalAccessToken - result := s.db.First(&token, "hashed_token = ?", hashedToken) + result := s.db.Take(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -589,7 +590,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren var user types.User result := tx. Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). - Where("personal_access_tokens.id = ?", patID).First(&user) + Where("personal_access_tokens.id = ?", patID).Take(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPATNotFoundError(patID) @@ -608,7 +609,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } var user types.User - result := tx.First(&user, idQueryCondition, userID) + result := tx.Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -619,16 +620,14 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error { +func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error { err := s.db.Transaction(func(tx *gorm.DB) error { - result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) + result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) if result.Error != nil { return result.Error } - return tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error + return tx.Delete(&types.User{}, accountAndIDQueryCondition, accountID, userID).Error }) if err != nil { log.WithContext(ctx).Errorf("failed to delete user from the store: %s", err) @@ -664,7 +663,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre } var user types.User - result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) + result := tx.Take(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account owner not found: index lookup failed") @@ -761,7 +760,7 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren var accountMeta types.AccountMeta result := tx.Model(&types.Account{}). - First(&accountMeta, idQueryCondition, accountID) + Take(&accountMeta, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account meta %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -776,7 +775,7 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren // GetAccountOnboarding retrieves the onboarding information for a specific account. func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) { var accountOnboarding types.AccountOnboarding - result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) + result := s.db.Model(&accountOnboarding).Take(&accountOnboarding, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountOnboardingNotFoundError(accountID) @@ -813,7 +812,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc Omit("GroupsG"). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). - First(&account, idQueryCondition, accountID) + Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -888,7 +887,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User - result := s.db.Select("account_id").First(&user, idQueryCondition, userID) + result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -905,7 +904,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types. func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) + result := s.db.Select("account_id").Take(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -922,7 +921,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*type func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, GetKeyQueryCondition(s), peerKey) + result := s.db.Select("account_id").Take(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -954,7 +953,7 @@ func (s *SqlStore) GetAnyAccountID(ctx context.Context) (string, error) { func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string - result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).First(&accountID) + result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).Take(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -973,7 +972,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin var accountID string result := tx.Model(&types.User{}). - Select("account_id").Where(idQueryCondition, userID).First(&accountID) + Select("account_id").Where(idQueryCondition, userID).Take(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -992,7 +991,7 @@ func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength Lockin var accountID string result := tx.Model(&nbpeer.Peer{}). - Select("account_id").Where(idQueryCondition, peerID).First(&accountID) + Select("account_id").Where(idQueryCondition, peerID).Take(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "peer %s account not found", peerID) @@ -1005,7 +1004,7 @@ func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength Lockin func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID) + result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).Take(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewSetupKeyNotFoundError(setupKey) @@ -1082,7 +1081,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } var accountNetwork types.AccountNetwork - if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -1098,7 +1097,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking } var peer nbpeer.Peer - result := tx.First(&peer, GetKeyQueryCondition(s), peerKey) + result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1117,7 +1116,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS } var accountSettings types.AccountSettings - if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -1134,7 +1133,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking var createdBy string result := tx.Model(&types.Account{}). - Select("created_by").First(&createdBy, idQueryCondition, accountID) + Select("created_by").Take(&createdBy, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewAccountNotFoundError(accountID) @@ -1148,7 +1147,7 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user types.User - result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) @@ -1171,7 +1170,7 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p } var postureCheck posture.Checks - err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error + err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).Take(&postureCheck).Error if err != nil { return nil, err } @@ -1336,7 +1335,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking var setupKey types.SetupKey result := tx. - First(&setupKey, GetKeyQueryCondition(s), key) + Take(&setupKey, GetKeyQueryCondition(s), key) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1446,7 +1445,7 @@ func (s *SqlStore) RemovePeerFromAllGroups(ctx context.Context, peerID string) e // AddResourceToGroup adds a resource to a group. Method always needs to run n a transaction func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error { var group types.Group - result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).Take(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewGroupNotFoundError(groupID) @@ -1473,7 +1472,7 @@ func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, gro // RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error { var group types.Group - result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).Take(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewGroupNotFoundError(groupID) @@ -1593,7 +1592,7 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt return peers, nil } -func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { +func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { if err := s.db.Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1610,7 +1609,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength var peer *nbpeer.Peer result := tx. - First(&peer, accountAndIDQueryCondition, accountID, peerID) + Take(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPeerNotFoundError(peerID) @@ -1705,9 +1704,8 @@ func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength Lockin } // DeletePeer removes a peer from the store. -func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID) +func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID string) error { + result := s.db.Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) return status.Errorf(status.Internal, "failed to delete peer from store") @@ -1720,9 +1718,8 @@ func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -1772,7 +1769,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki var accountDNSSettings types.AccountDNSSettings result := tx.Model(&types.Account{}). - First(&accountDNSSettings, idQueryCondition, accountID) + Take(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) @@ -1792,7 +1789,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng var accountID string result := tx.Model(&types.Account{}). - Select("id").First(&accountID, idQueryCondition, id) + Select("id").Take(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return false, nil @@ -1812,7 +1809,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength var account types.Account result := tx.Model(&types.Account{}).Select("domain", "domain_category"). - Where(idQueryCondition, accountID).First(&account) + Where(idQueryCondition, accountID).Take(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", "", status.Errorf(status.NotFound, "account not found") @@ -1831,7 +1828,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } var group *types.Group - result := tx.Preload(clause.Associations).First(&group, accountAndIDQueryCondition, accountID, groupID) + result := tx.Preload(clause.Associations).Take(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewGroupNotFoundError(groupID) @@ -1900,7 +1897,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } // CreateGroup creates a group in the store. -func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { +func (s *SqlStore) CreateGroup(ctx context.Context, group *types.Group) error { if group == nil { return status.Errorf(status.InvalidArgument, "group is nil") } @@ -1914,7 +1911,7 @@ func (s *SqlStore) CreateGroup(ctx context.Context, lockStrength LockingStrength } // UpdateGroup updates a group in the store. -func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { +func (s *SqlStore) UpdateGroup(ctx context.Context, group *types.Group) error { if group == nil { return status.Errorf(status.InvalidArgument, "group is nil") } @@ -1928,9 +1925,8 @@ func (s *SqlStore) UpdateGroup(ctx context.Context, lockStrength LockingStrength } // DeleteGroup deletes a group from the database. -func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Select(clause.Associations). +func (s *SqlStore) DeleteGroup(ctx context.Context, accountID, groupID string) error { + result := s.db.Select(clause.Associations). Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) @@ -1945,9 +1941,8 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength } // DeleteGroups deletes groups from the database. -func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { - result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Select(clause.Associations). +func (s *SqlStore) DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error { + result := s.db.Select(clause.Associations). Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) @@ -1985,7 +1980,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng var policy *types.Policy result := tx.Preload(clause.Associations). - First(&policy, accountAndIDQueryCondition, accountID, policyID) + Take(&policy, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewPolicyNotFoundError(policyID) @@ -1997,8 +1992,8 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return policy, nil } -func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) +func (s *SqlStore) CreatePolicy(ctx context.Context, policy *types.Policy) error { + result := s.db.Create(policy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) return status.Errorf(status.Internal, "failed to create policy in store") @@ -2008,9 +2003,8 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt } // SavePolicy saves a policy to the database. -func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { - result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). - Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) +func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).Save(policy) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) return status.Errorf(status.Internal, "failed to save policy to store") @@ -2018,13 +2012,13 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, return nil } -func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { +func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { return s.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } - result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + result := tx. Where(accountAndIDQueryCondition, accountID, policyID). Delete(&types.Policy{}) @@ -2067,7 +2061,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin var postureCheck *posture.Checks result := tx. - First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + Take(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPostureChecksNotFoundError(postureChecksID) @@ -2102,8 +2096,8 @@ func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength Locki } // SavePostureChecks saves a posture checks to the database. -func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) +func (s *SqlStore) SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error { + result := s.db.Save(postureCheck) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save posture checks to store") @@ -2113,9 +2107,8 @@ func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingSt } // DeletePostureChecks deletes a posture checks from the database. -func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) +func (s *SqlStore) DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error { + result := s.db.Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete posture checks from store") @@ -2130,9 +2123,13 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var routes []*route.Route - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Find(&routes, accountIDCondition, accountID) + result := tx.Find(&routes, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) return nil, status.Errorf(status.Internal, "failed to get routes from store") @@ -2143,9 +2140,13 @@ func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStr // GetRouteByID retrieves a route by its ID and account ID. func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + var route *route.Route - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&route, accountAndIDQueryCondition, accountID, routeID) + result := tx.Take(&route, accountAndIDQueryCondition, accountID, routeID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewRouteNotFoundError(routeID) @@ -2158,8 +2159,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt } // SaveRoute saves a route to the database. -func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) +func (s *SqlStore) SaveRoute(ctx context.Context, route *route.Route) error { + result := s.db.Save(route) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) return status.Errorf(status.Internal, "failed to save route to store") @@ -2169,9 +2170,8 @@ func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, } // DeleteRoute deletes a route from the database. -func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) +func (s *SqlStore) DeleteRoute(ctx context.Context, accountID, routeID string) error { + result := s.db.Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) return status.Errorf(status.Internal, "failed to delete route from store") @@ -2210,8 +2210,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } var setupKey *types.SetupKey - result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) + result := tx.Take(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewSetupKeyNotFoundError(setupKeyID) @@ -2224,8 +2223,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } // SaveSetupKey saves a setup key to the database. -func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) +func (s *SqlStore) SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error { + result := s.db.Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) return status.Errorf(status.Internal, "failed to save setup key to store") @@ -2235,8 +2234,8 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt } // DeleteSetupKey deletes a setup key from the database. -func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) +func (s *SqlStore) DeleteSetupKey(ctx context.Context, accountID, keyID string) error { + result := s.db.Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") @@ -2275,7 +2274,7 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock var nsGroup *nbdns.NameServerGroup result := tx. - First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) + Take(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewNameServerGroupNotFoundError(nsGroupID) @@ -2288,8 +2287,8 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock } // SaveNameServerGroup saves a name server group to the database. -func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) +func (s *SqlStore) SaveNameServerGroup(ctx context.Context, nameServerGroup *nbdns.NameServerGroup) error { + result := s.db.Save(nameServerGroup) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) return status.Errorf(status.Internal, "failed to save name server group to store") @@ -2298,8 +2297,8 @@ func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength Locking } // DeleteNameServerGroup deletes a name server group from the database. -func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) +func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID string) error { + result := s.db.Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) return status.Errorf(status.Internal, "failed to delete name server group from store") @@ -2313,8 +2312,8 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki } // SaveDNSSettings saves the DNS settings to the store. -func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). +func (s *SqlStore) SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error { + result := s.db.Model(&types.Account{}). Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings}) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) @@ -2329,8 +2328,8 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre } // SaveAccountSettings stores the account settings in DB. -func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). +func (s *SqlStore) SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error { + result := s.db.Model(&types.Account{}). Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings}) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error) @@ -2367,8 +2366,7 @@ func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStren } var network *networkTypes.Network - result := tx. - First(&network, accountAndIDQueryCondition, accountID, networkID) + result := tx.Take(&network, accountAndIDQueryCondition, accountID, networkID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewNetworkNotFoundError(networkID) @@ -2381,8 +2379,8 @@ func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStren return network, nil } -func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) +func (s *SqlStore) SaveNetwork(ctx context.Context, network *networkTypes.Network) error { + result := s.db.Save(network) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save network to store") @@ -2391,9 +2389,8 @@ func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength return nil } -func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID) +func (s *SqlStore) DeleteNetwork(ctx context.Context, accountID, networkID string) error { + result := s.db.Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete network from store") @@ -2448,7 +2445,7 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin var netRouter *routerTypes.NetworkRouter result := tx. - First(&netRouter, accountAndIDQueryCondition, accountID, routerID) + Take(&netRouter, accountAndIDQueryCondition, accountID, routerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewNetworkRouterNotFoundError(routerID) @@ -2460,8 +2457,8 @@ func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength Lockin return netRouter, nil } -func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router) +func (s *SqlStore) SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error { + result := s.db.Save(router) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save network router to store") @@ -2470,9 +2467,8 @@ func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingSt return nil } -func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID) +func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error { + result := s.db.Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete network router from store") @@ -2527,7 +2523,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock var netResources *resourceTypes.NetworkResource result := tx. - First(&netResources, accountAndIDQueryCondition, accountID, resourceID) + Take(&netResources, accountAndIDQueryCondition, accountID, resourceID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewNetworkResourceNotFoundError(resourceID) @@ -2547,7 +2543,7 @@ func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength Lo var netResources *resourceTypes.NetworkResource result := tx. - First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) + Take(&netResources, "account_id = ? AND name = ?", accountID, resourceName) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewNetworkResourceNotFoundError(resourceName) @@ -2559,8 +2555,8 @@ func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength Lo return netResources, nil } -func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource) +func (s *SqlStore) SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error { + result := s.db.Save(resource) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save network resource to store") @@ -2569,9 +2565,8 @@ func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength Locking return nil } -func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID) +func (s *SqlStore) DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error { + result := s.db.Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete network resource from store") @@ -2592,7 +2587,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking } var pat types.PersonalAccessToken - result := tx.First(&pat, "hashed_token = ?", hashedToken) + result := tx.Take(&pat, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPATNotFoundError(hashedToken) @@ -2613,7 +2608,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, var pat types.PersonalAccessToken result := tx. - First(&pat, "id = ? AND user_id = ?", patID, userID) + Take(&pat, "id = ? AND user_id = ?", patID, userID) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewPATNotFoundError(patID) @@ -2643,13 +2638,13 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength } // MarkPATUsed marks a personal access token as used. -func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error { +func (s *SqlStore) MarkPATUsed(ctx context.Context, patID string) error { patCopy := types.PersonalAccessToken{ LastUsed: util.ToPtr(time.Now().UTC()), } fieldsToUpdate := []string{"last_used"} - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Select(fieldsToUpdate). + result := s.db.Select(fieldsToUpdate). Where(idQueryCondition, patID).Updates(&patCopy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to mark pat as used: %s", result.Error) @@ -2664,8 +2659,8 @@ func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength } // SavePAT saves a personal access token to the database. -func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *types.PersonalAccessToken) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) +func (s *SqlStore) SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error { + result := s.db.Save(pat) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to save pat to the store: %s", err) return status.Errorf(status.Internal, "failed to save pat to store") @@ -2675,9 +2670,8 @@ func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pa } // DeletePAT deletes a personal access token from the database. -func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID) +func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error { + result := s.db.Delete(&types.PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete pat from the store: %s", err) return status.Errorf(status.Internal, "failed to delete pat from store") @@ -2700,7 +2694,7 @@ func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength var peer nbpeer.Peer result := tx. - First(&peer, "account_id = ? AND ip = ?", accountID, jsonValue) + Take(&peer, "account_id = ? AND ip = ?", accountID, jsonValue) if result.Error != nil { // no logging here return nil, status.Errorf(status.Internal, "failed to get peer from store") diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index fdce9d735..935b0a595 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -27,11 +27,11 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" nbroute "github.com/netbirdio/netbird/route" route2 "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" ) func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { @@ -401,11 +401,11 @@ func TestSqlite_DeleteAccount(t *testing.T) { } for _, network := range account.Networks { - routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthNone, account.Id, network.ID) require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network routers") require.Len(t, routers, 0, "expecting no network routers to be found after DeleteAccount") - resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthNone, account.Id, network.ID) require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources") require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") } @@ -459,7 +459,7 @@ func TestSqlStore_SavePeer(t *testing.T) { CreatedAt: time.Now().UTC(), } ctx := context.Background() - err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) + err = store.SavePeer(ctx, account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -475,7 +475,7 @@ func TestSqlStore_SavePeer(t *testing.T) { updatedPeer.Status.Connected = false updatedPeer.Meta.Hostname = "updatedpeer" - err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, updatedPeer) + err = store.SavePeer(ctx, account.Id, updatedPeer) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -499,7 +499,7 @@ func TestSqlStore_SavePeerStatus(t *testing.T) { // save status of non-existing peer newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) + err = store.SavePeerStatus(context.Background(), account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -518,7 +518,7 @@ func TestSqlStore_SavePeerStatus(t *testing.T) { err = store.SaveAccount(context.Background(), account) require.NoError(t, err) - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -532,7 +532,7 @@ func TestSqlStore_SavePeerStatus(t *testing.T) { newStatus.Connected = true - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) + err = store.SavePeerStatus(context.Background(), account.Id, "testpeer", newStatus) require.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -566,7 +566,7 @@ func TestSqlStore_SavePeerLocation(t *testing.T) { Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet - err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) + err = store.SavePeerLocation(context.Background(), account.Id, peer) assert.Error(t, err) account.Peers[peer.ID] = peer @@ -578,7 +578,7 @@ func TestSqlStore_SavePeerLocation(t *testing.T) { peer.Location.CityName = "Berlin" peer.Location.GeoNameID = 2950159 - err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, account.Peers[peer.ID]) + err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID]) assert.NoError(t, err) account, err = store.GetAccount(context.Background(), account.Id) @@ -588,7 +588,7 @@ func TestSqlStore_SavePeerLocation(t *testing.T) { assert.Equal(t, peer.Location, actual) peer.ID = "non-existing-peer" - err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) + err = store.SavePeerLocation(context.Background(), account.Id, peer) assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) @@ -961,7 +961,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) assert.Equal(t, []net.IP{}, takenIPs) @@ -971,10 +971,10 @@ func TestSqlite_GetTakenIPs(t *testing.T) { DNSLabel: "peer1", IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) - takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) ip1 := net.IP{1, 1, 1, 1}.To16() assert.Equal(t, []net.IP{ip1}, takenIPs) @@ -985,10 +985,10 @@ func TestSqlite_GetTakenIPs(t *testing.T) { DNSLabel: "peer1-1", IP: net.IP{2, 2, 2, 2}, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) - takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) ip2 := net.IP{2, 2, 2, 2}.To16() assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) @@ -1002,7 +1002,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { _, err := store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthNone, existingAccountID, peerHostname) require.NoError(t, err) assert.Equal(t, []string{}, labels) @@ -1012,10 +1012,10 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { DNSLabel: "peer1", IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthNone, existingAccountID, peerHostname) require.NoError(t, err) assert.Equal(t, []string{"peer1"}, labels) @@ -1025,10 +1025,10 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { DNSLabel: "peer1-1", IP: net.IP{2, 2, 2, 2}, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + err = store.AddPeerToAccount(context.Background(), peer2) require.NoError(t, err) - labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID, peerHostname) + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthNone, existingAccountID, peerHostname) require.NoError(t, err) expected := []string{"peer1", "peer1-1"} @@ -1050,7 +1050,7 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) peer2 := &nbpeer.Peer{ @@ -1058,7 +1058,7 @@ func Test_AddPeerWithSameDnsLabel(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + err = store.AddPeerToAccount(context.Background(), peer2) require.Error(t, err) }) } @@ -1075,7 +1075,7 @@ func Test_AddPeerWithSameIP(t *testing.T) { AccountID: existingAccountID, IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) + err = store.AddPeerToAccount(context.Background(), peer1) require.NoError(t, err) peer2 := &nbpeer.Peer{ @@ -1083,7 +1083,7 @@ func Test_AddPeerWithSameIP(t *testing.T) { AccountID: existingAccountID, IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) + err = store.AddPeerToAccount(context.Background(), peer2) require.Error(t, err) }) } @@ -1101,7 +1101,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) + network, err := store.GetAccountNetwork(context.Background(), LockingStrengthNone, existingAccountID) require.NoError(t, err) ip := net.IP{100, 64, 0, 0}.To16() assert.Equal(t, ip, network.Net.IP) @@ -1128,7 +1128,7 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthNone, encodedHashedKey) require.NoError(t, err) assert.Equal(t, encodedHashedKey, setupKey.Key) assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret) @@ -1153,21 +1153,21 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { _, err = store.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) - setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthNone, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 0, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthNone, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 1, setupKey.UsedTimes) err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) require.NoError(t, err) - setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthNone, encodedHashedKey) require.NoError(t, err) assert.Equal(t, 2, setupKey.UsedTimes) } @@ -1188,7 +1188,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { Peers: nil, } err = store.ExecuteInTransaction(context.Background(), func(transaction Store) error { - err := transaction.CreateGroup(context.Background(), LockingStrengthUpdate, group) + err := transaction.CreateGroup(context.Background(), group) if err != nil { t.Fatal("failed to save group") return err @@ -1213,7 +1213,7 @@ func TestSqlStore_GetAccountUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } @@ -1272,7 +1272,7 @@ func TestSqlite_GetGroupByName(t *testing.T) { } accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") + group, err := store.GetGroupByName(context.Background(), LockingStrengthNone, accountID, "All") require.NoError(t, err) require.True(t, group.IsGroupAll()) } @@ -1286,10 +1286,10 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" setupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, setupKeyID) + err = store.DeleteSetupKey(context.Background(), accountID, setupKeyID) require.NoError(t, err) - _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthShare, setupKeyID, accountID) + _, err = store.GetSetupKeyByID(context.Background(), LockingStrengthNone, setupKeyID, accountID) require.Error(t, err) } @@ -1302,7 +1302,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" nonExistingKeyID := "non-existing-key-id" - err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) + err = store.DeleteSetupKey(context.Background(), accountID, nonExistingKeyID) require.Error(t, err) } @@ -1342,7 +1342,7 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthNone, accountID, tt.groupIDs) require.NoError(t, err) require.Len(t, groups, tt.expectedCount) }) @@ -1365,10 +1365,10 @@ func TestSqlStore_CreateGroup(t *testing.T) { Resources: []types.Resource{}, GroupPeers: []types.GroupPeer{}, } - err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group) + err = store.CreateGroup(context.Background(), group) require.NoError(t, err) - savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, "group-id") require.NoError(t, err) require.Equal(t, savedGroup, group) } @@ -1398,14 +1398,14 @@ func TestSqlStore_CreateUpdateGroups(t *testing.T) { GroupPeers: []types.GroupPeer{}, }, } - err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groups) + err = store.CreateGroups(context.Background(), accountID, groups) require.NoError(t, err) groups[1].Peers = []string{} - err = store.UpdateGroups(context.Background(), LockingStrengthUpdate, accountID, groups) + err = store.UpdateGroups(context.Background(), accountID, groups) require.NoError(t, err) - group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groups[1].ID) + group, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groups[1].ID) require.NoError(t, err) require.Equal(t, groups[1], group) } @@ -1441,7 +1441,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + err := store.DeleteGroup(context.Background(), accountID, tt.groupID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -1450,7 +1450,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) { } else { require.NoError(t, err) - group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + group, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, tt.groupID) require.Error(t, err) require.Nil(t, group) } @@ -1489,14 +1489,14 @@ func TestSqlStore_DeleteGroups(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + err := store.DeleteGroups(context.Background(), accountID, tt.groupIDs) if tt.expectError { require.Error(t, err) } else { require.NoError(t, err) for _, groupID := range tt.groupIDs { - group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.Error(t, err) require.Nil(t, group) } @@ -1535,7 +1535,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + peer, err := store.GetPeerByID(context.Background(), LockingStrengthNone, accountID, tt.peerID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -1586,7 +1586,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthNone, accountID, tt.peerIDs) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) }) @@ -1623,7 +1623,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthNone, accountID, tt.postureChecksID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -1675,7 +1675,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthNone, accountID, tt.postureCheckIDs) require.NoError(t, err) require.Len(t, groups, tt.expectedCount) }) @@ -1715,10 +1715,10 @@ func TestSqlStore_SavePostureChecks(t *testing.T) { }, }, } - err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + err = store.SavePostureChecks(context.Background(), postureChecks) require.NoError(t, err) - savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthNone, accountID, "posture-checks-id") require.NoError(t, err) require.Equal(t, savePostureChecks, postureChecks) } @@ -1754,7 +1754,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + err = store.DeletePostureChecks(context.Background(), accountID, tt.postureChecksID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -1762,7 +1762,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { require.Equal(t, sErr.Type(), status.NotFound) } else { require.NoError(t, err) - group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthNone, accountID, tt.postureChecksID) require.Error(t, err) require.Nil(t, group) } @@ -1800,7 +1800,7 @@ func TestSqlStore_GetPolicyByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthNone, accountID, tt.policyID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -1837,10 +1837,10 @@ func TestSqlStore_CreatePolicy(t *testing.T) { }, }, } - err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + err = store.CreatePolicy(context.Background(), policy) require.NoError(t, err) - savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthNone, accountID, policy.ID) require.NoError(t, err) require.Equal(t, savePolicy, policy) @@ -1854,17 +1854,17 @@ func TestSqlStore_SavePolicy(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" policyID := "cs1tnh0hhcjnqoiuebf0" - policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthNone, accountID, policyID) require.NoError(t, err) policy.Enabled = false policy.Description = "policy" policy.Rules[0].Sources = []string{"group"} policy.Rules[0].Ports = []string{"80", "443"} - err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + err = store.SavePolicy(context.Background(), policy) require.NoError(t, err) - savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthNone, accountID, policy.ID) require.NoError(t, err) require.Equal(t, savePolicy, policy) } @@ -1877,10 +1877,10 @@ func TestSqlStore_DeletePolicy(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" policyID := "cs1tnh0hhcjnqoiuebf0" - err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + err = store.DeletePolicy(context.Background(), accountID, policyID) require.NoError(t, err) - policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthNone, accountID, policyID) require.Error(t, err) require.Nil(t, policy) } @@ -1914,7 +1914,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID) + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthNone, tt.accountID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -1936,14 +1936,14 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"} - err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings) + err = store.SaveDNSSettings(context.Background(), accountID, dnsSettings) require.NoError(t, err) - saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Equal(t, saveDNSSettings, dnsSettings) } @@ -1977,7 +1977,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID) + peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) }) @@ -2015,7 +2015,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID) + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthNone, accountID, tt.nsGroupID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -2055,10 +2055,10 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) { SearchDomainsEnabled: false, } - err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup) + err = store.SaveNameServerGroup(context.Background(), nsGroup) require.NoError(t, err) - saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID) + saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthNone, accountID, nsGroup.ID) require.NoError(t, err) require.Equal(t, saveNSGroup, nsGroup) } @@ -2071,10 +2071,10 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" nsGroupID := "csqdelq7qv97ncu7d9t0" - err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID) + err = store.DeleteNameServerGroup(context.Background(), accountID, nsGroupID) require.NoError(t, err) - nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID) + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthNone, accountID, nsGroupID) require.Error(t, err) require.Nil(t, nsGroup) } @@ -2154,7 +2154,7 @@ func TestSqlStore_GetAccountNetworks(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthShare, tt.accountID) + networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, networks, tt.expectedCount) }) @@ -2191,7 +2191,7 @@ func TestSqlStore_GetNetworkByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + network, err := store.GetNetworkByID(context.Background(), LockingStrengthNone, accountID, tt.networkID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -2219,10 +2219,10 @@ func TestSqlStore_SaveNetwork(t *testing.T) { Name: "net", } - err = store.SaveNetwork(context.Background(), LockingStrengthUpdate, network) + err = store.SaveNetwork(context.Background(), network) require.NoError(t, err) - savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, network.ID) + savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthNone, accountID, network.ID) require.NoError(t, err) require.Equal(t, network, savedNet) } @@ -2235,10 +2235,10 @@ func TestSqlStore_DeleteNetwork(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" networkID := "ct286bi7qv930dsrrug0" - err = store.DeleteNetwork(context.Background(), LockingStrengthUpdate, accountID, networkID) + err = store.DeleteNetwork(context.Background(), accountID, networkID) require.NoError(t, err) - network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, networkID) + network, err := store.GetNetworkByID(context.Background(), LockingStrengthNone, accountID, networkID) require.Error(t, err) sErr, ok := status.FromError(err) require.True(t, ok) @@ -2272,7 +2272,7 @@ func TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthNone, accountID, tt.networkID) require.NoError(t, err) require.Len(t, routers, tt.expectedCount) }) @@ -2309,7 +2309,7 @@ func TestSqlStore_GetNetworkRouterByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - networkRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, tt.networkRouterID) + networkRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, tt.networkRouterID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -2336,10 +2336,10 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) { netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true) require.NoError(t, err) - err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter) + err = store.SaveNetworkRouter(context.Background(), netRouter) require.NoError(t, err) - savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, netRouter.ID) + savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthNone, accountID, netRouter.ID) require.NoError(t, err) require.Equal(t, netRouter, savedNetRouter) } @@ -2352,10 +2352,10 @@ func TestSqlStore_DeleteNetworkRouter(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" netRouterID := "ctc20ji7qv9ck2sebc80" - err = store.DeleteNetworkRouter(context.Background(), LockingStrengthUpdate, accountID, netRouterID) + err = store.DeleteNetworkRouter(context.Background(), accountID, netRouterID) require.NoError(t, err) - netRouter, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netRouterID) + netRouter, err := store.GetNetworkByID(context.Background(), LockingStrengthNone, accountID, netRouterID) require.Error(t, err) sErr, ok := status.FromError(err) require.True(t, ok) @@ -2389,7 +2389,7 @@ func TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - netResources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + netResources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthNone, accountID, tt.networkID) require.NoError(t, err) require.Len(t, netResources, tt.expectedCount) }) @@ -2426,7 +2426,7 @@ func TestSqlStore_GetNetworkResourceByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - netResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, tt.netResourceID) + netResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthNone, accountID, tt.netResourceID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -2453,10 +2453,10 @@ func TestSqlStore_SaveNetworkResource(t *testing.T) { netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}, true) require.NoError(t, err) - err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) + err = store.SaveNetworkResource(context.Background(), netResource) require.NoError(t, err) - savedNetResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, netResource.ID) + savedNetResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthNone, accountID, netResource.ID) require.NoError(t, err) require.Equal(t, netResource.ID, savedNetResource.ID) require.Equal(t, netResource.Name, savedNetResource.Name) @@ -2475,10 +2475,10 @@ func TestSqlStore_DeleteNetworkResource(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" netResourceID := "ctc4nci7qv9061u6ilfg" - err = store.DeleteNetworkResource(context.Background(), LockingStrengthUpdate, accountID, netResourceID) + err = store.DeleteNetworkResource(context.Background(), accountID, netResourceID) require.NoError(t, err) - netResource, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netResourceID) + netResource, err := store.GetNetworkByID(context.Background(), LockingStrengthNone, accountID, netResourceID) require.Error(t, err) sErr, ok := status.FromError(err) require.True(t, ok) @@ -2502,18 +2502,18 @@ func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { err = store.AddResourceToGroup(context.Background(), accountID, groupID, res) require.NoError(t, err) - group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.NoError(t, err) require.Contains(t, group.Resources, *res) - groups, err := store.GetResourceGroups(context.Background(), LockingStrengthShare, accountID, resourceId) + groups, err := store.GetResourceGroups(context.Background(), LockingStrengthNone, accountID, resourceId) require.NoError(t, err) require.Len(t, groups, 1) err = store.RemoveResourceFromGroup(context.Background(), accountID, groupID, res.ID) require.NoError(t, err) - group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.NoError(t, err) require.NotContains(t, group.Resources, *res) } @@ -2527,14 +2527,14 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) { peerID := "cfefqs706sqkneg59g4g" groupID := "cfefqs706sqkneg59g4h" - group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.NoError(t, err, "failed to get group") require.Len(t, group.Peers, 0, "group should have 0 peers") err = store.AddPeerToGroup(context.Background(), accountID, peerID, groupID) require.NoError(t, err, "failed to add peer to group") - group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.NoError(t, err, "failed to get group") require.Len(t, group.Peers, 1, "group should have 1 peers") require.Contains(t, group.Peers, peerID) @@ -2554,18 +2554,18 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) { DNSLabel: "peer1.domain.test", } - group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err := store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.NoError(t, err, "failed to get group") require.Len(t, group.Peers, 2, "group should have 2 peers") require.NotContains(t, group.Peers, peer.ID) - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + err = store.AddPeerToAccount(context.Background(), peer) require.NoError(t, err, "failed to add peer to account") err = store.AddPeerToAllGroup(context.Background(), accountID, peer.ID) require.NoError(t, err, "failed to add peer to all group") - group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = store.GetGroupByID(context.Background(), LockingStrengthNone, accountID, groupID) require.NoError(t, err, "failed to get group") require.Len(t, group.Peers, 3, "group should have peers") require.Contains(t, group.Peers, peer.ID) @@ -2609,10 +2609,10 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { CreatedAt: time.Now().UTC(), Ephemeral: true, } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + err = store.AddPeerToAccount(context.Background(), peer) require.NoError(t, err, "failed to add peer to account") - storedPeer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peer.ID) + storedPeer, err := store.GetPeerByID(context.Background(), LockingStrengthNone, accountID, peer.ID) require.NoError(t, err, "failed to get peer") assert.Equal(t, peer.ID, storedPeer.ID) @@ -2643,7 +2643,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" peerID := "cfefqs706sqkneg59g4g" - groups, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + groups, err := store.GetPeerGroups(context.Background(), LockingStrengthNone, accountID, peerID) require.NoError(t, err) assert.Len(t, groups, 1) assert.Equal(t, groups[0].Name, "All") @@ -2651,7 +2651,7 @@ func TestSqlStore_GetPeerGroups(t *testing.T) { err = store.AddPeerToGroup(context.Background(), accountID, peerID, "cfefqs706sqkneg59g4h") require.NoError(t, err) - groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + groups, err = store.GetPeerGroups(context.Background(), LockingStrengthNone, accountID, peerID) require.NoError(t, err) assert.Len(t, groups, 2) } @@ -2705,7 +2705,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID, tt.nameFilter, tt.ipFilter) + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthNone, tt.accountID, tt.nameFilter, tt.ipFilter) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) }) @@ -2742,7 +2742,7 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID) + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) }) @@ -2778,7 +2778,7 @@ func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID) + peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) }) @@ -2790,7 +2790,7 @@ func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { t.Cleanup(cleanup) require.NoError(t, err) - peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare) + peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthNone) require.NoError(t, err) require.Len(t, peers, 1) require.True(t, peers[0].Ephemeral) @@ -2841,7 +2841,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - peers, err := store.GetUserPeers(context.Background(), LockingStrengthShare, tt.accountID, tt.userID) + peers, err := store.GetUserPeers(context.Background(), LockingStrengthNone, tt.accountID, tt.userID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) }) @@ -2856,10 +2856,10 @@ func TestSqlStore_DeletePeer(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" peerID := "csrnkiq7qv9d8aitqd50" - err = store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) + err = store.DeletePeer(context.Background(), accountID, peerID) require.NoError(t, err) - peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) + peer, err := store.GetPeerByID(context.Background(), LockingStrengthNone, accountID, peerID) require.Error(t, err) require.Nil(t, peer) } @@ -2888,7 +2888,7 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) { <-start err := store.ExecuteInTransaction(context.Background(), func(tx Store) error { - _, err := tx.GetAccountIDByPeerID(context.Background(), LockingStrengthShare, "cfvprsrlo1hqoo49ohog") + _, err := tx.GetAccountIDByPeerID(context.Background(), LockingStrengthNone, "cfvprsrlo1hqoo49ohog") return err }) if err != nil { @@ -2906,7 +2906,7 @@ func TestSqlStore_DatabaseBlocking(t *testing.T) { t.Logf("Entered routine 2-%d", i) <-start - _, err := store.GetAccountIDByPeerID(context.Background(), LockingStrengthShare, "cfvprsrlo1hqoo49ohog") + _, err := store.GetAccountIDByPeerID(context.Background(), LockingStrengthNone, "cfvprsrlo1hqoo49ohog") if err != nil { t.Errorf("Failed, got error: %v", err) return @@ -2965,7 +2965,7 @@ func TestSqlStore_GetAccountCreatedBy(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthShare, tt.accountID) + createdBy, err := store.GetAccountCreatedBy(context.Background(), LockingStrengthNone, tt.accountID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -3011,7 +3011,7 @@ func TestSqlStore_GetUserByUserID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, tt.userID) + user, err := store.GetUserByUserID(context.Background(), LockingStrengthNone, tt.userID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -3034,7 +3034,7 @@ func TestSqlStore_GetUserByPATID(t *testing.T) { id := "9dj38s35-63fb-11ec-90d6-0242ac120003" - user, err := store.GetUserByPATID(context.Background(), LockingStrengthShare, id) + user, err := store.GetUserByPATID(context.Background(), LockingStrengthNone, id) require.NoError(t, err) require.Equal(t, "f4f6d672-63fb-11ec-90d6-0242ac120003", user.Id) } @@ -3057,10 +3057,10 @@ func TestSqlStore_SaveUser(t *testing.T) { CreatedAt: time.Now().UTC().Add(-time.Hour), Issued: types.UserIssuedIntegration, } - err = store.SaveUser(context.Background(), LockingStrengthUpdate, user) + err = store.SaveUser(context.Background(), user) require.NoError(t, err) - saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, user.Id) + saveUser, err := store.GetUserByUserID(context.Background(), LockingStrengthNone, user.Id) require.NoError(t, err) require.Equal(t, user.Id, saveUser.Id) require.Equal(t, user.AccountID, saveUser.AccountID) @@ -3080,7 +3080,7 @@ func TestSqlStore_SaveUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Len(t, accountUsers, 2) @@ -3098,18 +3098,18 @@ func TestSqlStore_SaveUsers(t *testing.T) { AutoGroups: []string{"groupA"}, }, } - err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users) + err = store.SaveUsers(context.Background(), users) require.NoError(t, err) - accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Len(t, accountUsers, 4) users[1].AutoGroups = []string{"groupA", "groupC"} - err = store.SaveUsers(context.Background(), LockingStrengthUpdate, users) + err = store.SaveUsers(context.Background(), users) require.NoError(t, err) - user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, users[1].Id) + user, err := store.GetUserByUserID(context.Background(), LockingStrengthNone, users[1].Id) require.NoError(t, err) require.Equal(t, users[1].AutoGroups, user.AutoGroups) } @@ -3122,14 +3122,14 @@ func TestSqlStore_DeleteUser(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" - err = store.DeleteUser(context.Background(), LockingStrengthUpdate, accountID, userID) + err = store.DeleteUser(context.Background(), accountID, userID) require.NoError(t, err) - user, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, userID) + user, err := store.GetUserByUserID(context.Background(), LockingStrengthNone, userID) require.Error(t, err) require.Nil(t, user) - userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, userID) + userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthNone, userID) require.NoError(t, err) require.Len(t, userPATs, 0) } @@ -3165,7 +3165,7 @@ func TestSqlStore_GetPATByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, tt.patID) + pat, err := store.GetPATByID(context.Background(), LockingStrengthNone, userID, tt.patID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -3186,7 +3186,7 @@ func TestSqlStore_GetUserPATs(t *testing.T) { t.Cleanup(cleanup) require.NoError(t, err) - userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthShare, "f4f6d672-63fb-11ec-90d6-0242ac120003") + userPATs, err := store.GetUserPATs(context.Background(), LockingStrengthNone, "f4f6d672-63fb-11ec-90d6-0242ac120003") require.NoError(t, err) require.Len(t, userPATs, 1) } @@ -3196,7 +3196,7 @@ func TestSqlStore_GetPATByHashedToken(t *testing.T) { t.Cleanup(cleanup) require.NoError(t, err) - pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthShare, "SoMeHaShEdToKeN") + pat, err := store.GetPATByHashedToken(context.Background(), LockingStrengthNone, "SoMeHaShEdToKeN") require.NoError(t, err) require.Equal(t, "9dj38s35-63fb-11ec-90d6-0242ac120003", pat.ID) } @@ -3209,10 +3209,10 @@ func TestSqlStore_MarkPATUsed(t *testing.T) { userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" patID := "9dj38s35-63fb-11ec-90d6-0242ac120003" - err = store.MarkPATUsed(context.Background(), LockingStrengthUpdate, patID) + err = store.MarkPATUsed(context.Background(), patID) require.NoError(t, err) - pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID) + pat, err := store.GetPATByID(context.Background(), LockingStrengthNone, userID, patID) require.NoError(t, err) now := time.Now().UTC() require.WithinRange(t, pat.LastUsed.UTC(), now.Add(-15*time.Second), now, "LastUsed should be within 1 second of now") @@ -3235,10 +3235,10 @@ func TestSqlStore_SavePAT(t *testing.T) { CreatedAt: time.Now().UTC().Add(time.Hour), LastUsed: util.ToPtr(time.Now().UTC().Add(-15 * time.Minute)), } - err = store.SavePAT(context.Background(), LockingStrengthUpdate, pat) + err = store.SavePAT(context.Background(), pat) require.NoError(t, err) - savePAT, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, pat.ID) + savePAT, err := store.GetPATByID(context.Background(), LockingStrengthNone, userID, pat.ID) require.NoError(t, err) require.Equal(t, pat.ID, savePAT.ID) require.Equal(t, pat.UserID, savePAT.UserID) @@ -3257,10 +3257,10 @@ func TestSqlStore_DeletePAT(t *testing.T) { userID := "f4f6d672-63fb-11ec-90d6-0242ac120003" patID := "9dj38s35-63fb-11ec-90d6-0242ac120003" - err = store.DeletePAT(context.Background(), LockingStrengthUpdate, userID, patID) + err = store.DeletePAT(context.Background(), userID, patID) require.NoError(t, err) - pat, err := store.GetPATByID(context.Background(), LockingStrengthShare, userID, patID) + pat, err := store.GetPATByID(context.Background(), LockingStrengthNone, userID, patID) require.Error(t, err) require.Nil(t, pat) } @@ -3272,7 +3272,7 @@ func TestSqlStore_SaveUsers_LargeBatch(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + accountUsers, err := store.GetAccountUsers(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Len(t, accountUsers, 2) @@ -3286,10 +3286,10 @@ func TestSqlStore_SaveUsers_LargeBatch(t *testing.T) { }) } - err = store.SaveUsers(context.Background(), LockingStrengthUpdate, usersToSave) + err = store.SaveUsers(context.Background(), usersToSave) require.NoError(t, err) - accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) + accountUsers, err = store.GetAccountUsers(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Equal(t, 8002, len(accountUsers)) } @@ -3301,7 +3301,7 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - accountGroups, err := store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + accountGroups, err := store.GetAccountGroups(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Len(t, accountGroups, 3) @@ -3315,10 +3315,10 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) { }) } - err = store.CreateGroups(context.Background(), LockingStrengthUpdate, accountID, groupsToSave) + err = store.CreateGroups(context.Background(), accountID, groupsToSave) require.NoError(t, err) - accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + accountGroups, err = store.GetAccountGroups(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.Equal(t, 8003, len(accountGroups)) } @@ -3351,7 +3351,7 @@ func TestSqlStore_GetAccountRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID) + routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, routes, tt.expectedCount) }) @@ -3388,7 +3388,7 @@ func TestSqlStore_GetRouteByID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID) + route, err := store.GetRouteByID(context.Background(), LockingStrengthNone, accountID, tt.routeID) if tt.expectError { require.Error(t, err) sErr, ok := status.FromError(err) @@ -3424,10 +3424,10 @@ func TestSqlStore_SaveRoute(t *testing.T) { Groups: []string{"groupA"}, AccessControlGroups: []string{}, } - err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + err = store.SaveRoute(context.Background(), route) require.NoError(t, err) - saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID)) + saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthNone, accountID, string(route.ID)) require.NoError(t, err) require.Equal(t, route, saveRoute) @@ -3441,10 +3441,10 @@ func TestSqlStore_DeleteRoute(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" routeID := "ct03t427qv97vmtmglog" - err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID) + err = store.DeleteRoute(context.Background(), accountID, routeID) require.NoError(t, err) - route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID) + route, err := store.GetRouteByID(context.Background(), LockingStrengthNone, accountID, routeID) require.Error(t, err) require.Nil(t, route) } @@ -3455,7 +3455,7 @@ func TestSqlStore_GetAccountMeta(t *testing.T) { require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - accountMeta, err := store.GetAccountMeta(context.Background(), LockingStrengthShare, accountID) + accountMeta, err := store.GetAccountMeta(context.Background(), LockingStrengthNone, accountID) require.NoError(t, err) require.NotNil(t, accountMeta) require.Equal(t, accountID, accountMeta.AccountID) @@ -3567,7 +3567,7 @@ func BenchmarkGetAccountPeers(b *testing.B) { DNSLabel: fmt.Sprintf("peer%d.example.com", i), IP: intToIPv4(uint32(i)), } - err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + err = store.AddPeerToAccount(context.Background(), peer) if err != nil { b.Fatalf("Failed to add peer: %v", err) } @@ -3580,7 +3580,7 @@ func BenchmarkGetAccountPeers(b *testing.B) { ID: groupID, AccountID: accountID, } - err = store.CreateGroup(context.Background(), LockingStrengthUpdate, group) + err = store.CreateGroup(context.Background(), group) if err != nil { b.Fatalf("Failed to create group: %v", err) } @@ -3595,7 +3595,7 @@ func BenchmarkGetAccountPeers(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peers[i%numberOfPeers].ID) + _, err := store.GetPeerGroups(context.Background(), LockingStrengthNone, accountID, peers[i%numberOfPeers].ID) if err != nil { b.Fatal(err) } diff --git a/management/server/store/store.go b/management/server/store/store.go index 912939bc2..da4459256 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -72,8 +72,8 @@ type Store interface { SaveAccount(ctx context.Context, account *types.Account) error DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error - SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error - SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error + SaveDNSSettings(ctx context.Context, accountID string, settings *types.DNSSettings) error + SaveAccountSettings(ctx context.Context, accountID string, settings *types.Settings) error CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error @@ -81,10 +81,10 @@ type Store interface { GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) - SaveUsers(ctx context.Context, lockStrength LockingStrength, users []*types.User) error - SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error + SaveUsers(ctx context.Context, users []*types.User) error + SaveUser(ctx context.Context, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error - DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error + DeleteUser(ctx context.Context, accountID, userID string) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error @@ -92,34 +92,34 @@ type Store interface { GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*types.PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) - MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error - SavePAT(ctx context.Context, strength LockingStrength, pat *types.PersonalAccessToken) error - DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error + MarkPATUsed(ctx context.Context, patID string) error + SavePAT(ctx context.Context, pat *types.PersonalAccessToken) error + DeletePAT(ctx context.Context, userID, patID string) error GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) - CreateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error - UpdateGroups(ctx context.Context, lockStrength LockingStrength, accountID string, groups []*types.Group) error - CreateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error - UpdateGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error - DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error - DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error + CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error + UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error + CreateGroup(ctx context.Context, group *types.Group) error + UpdateGroup(ctx context.Context, group *types.Group) error + DeleteGroup(ctx context.Context, accountID, groupID string) error + DeleteGroups(ctx context.Context, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) - CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error - SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error - DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error + CreatePolicy(ctx context.Context, policy *types.Policy) error + SavePolicy(ctx context.Context, policy *types.Policy) error + DeletePolicy(ctx context.Context, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) - SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error - DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error + SavePostureChecks(ctx context.Context, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string, hostname string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error @@ -130,7 +130,7 @@ type Store interface { GetPeerGroupIDs(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]string, error) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error - AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) @@ -139,30 +139,30 @@ type Store interface { GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) - SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error - SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error - SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error - DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error + SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error + SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error + SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error + DeletePeer(ctx context.Context, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) - SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error - DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error + SaveSetupKey(ctx context.Context, setupKey *types.SetupKey) error + DeleteSetupKey(ctx context.Context, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) - SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error - DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error + SaveRoute(ctx context.Context, route *route.Route) error + DeleteRoute(ctx context.Context, accountID, routeID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) - SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error - DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error + SaveNameServerGroup(ctx context.Context, nameServerGroup *dns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, accountID, nameServerGroupID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error + IncrementNetworkSerial(ctx context.Context, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetInstallationID() string @@ -184,21 +184,21 @@ type Store interface { GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) - SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error - DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error + SaveNetwork(ctx context.Context, network *networkTypes.Network) error + DeleteNetwork(ctx context.Context, accountID, networkID string) error GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) - SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error - DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error + SaveNetworkRouter(ctx context.Context, router *routerTypes.NetworkRouter) error + DeleteNetworkRouter(ctx context.Context, accountID, routerID string) error GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) - SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error - DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error + SaveNetworkResource(ctx context.Context, resource *resourceTypes.NetworkResource) error + DeleteNetworkResource(ctx context.Context, accountID, resourceID string) error GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) diff --git a/management/server/user.go b/management/server/user.go index e33f89dfd..ba1835f22 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -17,11 +17,11 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/status" ) // createServiceUser creates a new service user under the given account. @@ -46,7 +46,7 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI newUser.AccountID = accountID log.WithContext(ctx).Debugf("New User: %v", newUser) - if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { + if err = am.Store.SaveUser(ctx, newUser); err != nil { return nil, err } @@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, status.NewPermissionDeniedError() } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err } inviterID := userID if initiatorUser.IsServiceUser { - createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) + createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -124,7 +124,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u CreatedAt: time.Now().UTC(), } - if err = am.Store.SaveUser(ctx, store.LockingStrengthUpdate, newUser); err != nil { + if err = am.Store.SaveUser(ctx, newUser); err != nil { return nil, err } @@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID } func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { - return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) + return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id) } // GetUser looks up a user by provided nbContext.UserAuths. // Expects account to have been created already. func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId) if err != nil { return nil, err } @@ -209,11 +209,11 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { - return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) } func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { - if err := am.Store.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUser.Id); err != nil { + if err := am.Store.DeleteUser(ctx, accountID, targetUser.Id); err != nil { return err } meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} @@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err } @@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.NewPermissionDeniedError() } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return err } @@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.NewPermissionDeniedError() } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return nil, err } @@ -367,7 +367,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } - if err = am.Store.SavePAT(ctx, store.LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { + if err = am.Store.SavePAT(ctx, &pat.PersonalAccessToken); err != nil { return nil, err } @@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return status.NewPermissionDeniedError() } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return err } @@ -404,12 +404,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string return status.NewAdminPermissionError() } - pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) + pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID) if err != nil { return err } - if err = am.Store.DeletePAT(ctx, store.LockingStrengthUpdate, targetUserID, tokenID); err != nil { + if err = am.Store.DeletePAT(ctx, targetUserID, tokenID); err != nil { return err } @@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, status.NewPermissionDeniedError() } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return nil, err } @@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, status.NewAdminPermissionError() } - return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) + return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID) } // GetAllPATs returns all PATs for a user @@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, status.NewPermissionDeniedError() } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { return nil, err } @@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, status.NewAdminPermissionError() } - return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) + return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID) } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. @@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, if !allowed { return nil, status.NewPermissionDeniedError() } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, var addUserEvents []func() var usersToSave = make([]*types.User, 0, len(updates)) - groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("error getting account groups: %w", err) } @@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, var initiatorUser *types.User if initiatorUserID != activity.SystemInitiator { - result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, err } @@ -560,7 +560,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, updateAccountPeers = true } } - return transaction.SaveUsers(ctx, store.LockingStrengthUpdate, usersToSave) + return transaction.SaveUsers(ctx, usersToSave) }) if err != nil { return nil, err @@ -593,7 +593,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if settings.GroupsPropagationEnabled && updateAccountPeers { - if err = am.Store.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + if err = am.Store.IncrementNetworkSerial(ctx, accountID); err != nil { return nil, fmt.Errorf("failed to increment network serial: %w", err) } am.UpdateAccountPeers(ctx, accountID) @@ -700,7 +700,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { - existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) + existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id) if err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if !addIfNotExists { @@ -724,7 +724,7 @@ func handleOwnerRoleTransfer(ctx context.Context, transaction store.Store, initi newInitiatorUser := initiatorUser.Copy() newInitiatorUser.Role = types.UserRoleAdmin - if err := transaction.SaveUser(ctx, store.LockingStrengthUpdate, newInitiatorUser); err != nil { + if err := transaction.SaveUser(ctx, newInitiatorUser); err != nil { return false, err } return true, nil @@ -835,7 +835,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun var user *types.User if initiatorUserID != activity.SystemInitiator { - result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return nil, fmt.Errorf("failed to get user: %w", err) } @@ -845,7 +845,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun accountUsers := []*types.User{} switch { case allowed: - accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -939,7 +939,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a // expireAndUpdatePeers expires all peers of the given user and updates them in the account func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { log.WithContext(ctx).Debugf("Expiring %d peers for account %s", len(peers), accountID) - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } @@ -956,7 +956,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + if err := am.Store.SavePeerStatus(ctx, accountID, peer.ID, *peer.Status); err != nil { return err } am.StoreEvent( @@ -1009,7 +1009,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return status.NewPermissionDeniedError() } - initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID) if err != nil { return err } @@ -1023,7 +1023,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID) if err != nil { allErrors = errors.Join(allErrors, err) continue @@ -1087,12 +1087,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID) + targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user to delete: %w", err) } - userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID) + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID) if err != nil { return fmt.Errorf("failed to get user peers: %w", err) } @@ -1105,7 +1105,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI } } - if err = transaction.DeleteUser(ctx, store.LockingStrengthUpdate, accountID, targetUserInfo.ID); err != nil { + if err = transaction.DeleteUser(ctx, accountID, targetUserInfo.ID); err != nil { return fmt.Errorf("failed to delete user: %s %w", targetUserInfo.ID, err) } @@ -1126,7 +1126,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI // GetOwnerInfo retrieves the owner information for a given account ID. func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) { - owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID) + owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } @@ -1176,7 +1176,7 @@ func validateUserInvite(invite *types.UserInfo) error { func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { accountID, userID := userAuth.AccountId, userAuth.UserId - user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) if err != nil { return nil, err } @@ -1193,7 +1193,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut return nil, err } - settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, err } diff --git a/management/server/user_test.go b/management/server/user_test.go index 7515439b8..8ab0c1565 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -15,9 +15,9 @@ import ( "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/roles" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { assert.Equal(t, pat.ID, tokenID) - user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) + user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID) if err != nil { t.Fatalf("Error when getting user by token ID: %s", err) } @@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) { _, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true) assert.Error(t, err, "update user to another account should fail") - user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId) + user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId) require.NoError(t, err) assert.Equal(t, account1.Users[targetId].Id, user.Id) assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) diff --git a/management/server/users/manager.go b/management/server/users/manager.go index 718eb6190..e07f28706 100644 --- a/management/server/users/manager.go +++ b/management/server/users/manager.go @@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager { } func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { - return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID) } func NewManagerMock() Manager {