diff --git a/management/server/account.go b/management/server/account.go index da1e43370..403be7286 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -11,6 +11,7 @@ import ( "net/netip" "reflect" "regexp" + "runtime/debug" "strings" "sync" "time" @@ -76,7 +77,7 @@ type AccountManager interface { GetUser(claims jwtclaims.AuthorizationClaims) (*User, error) ListUsers(accountID string) ([]*User, error) GetPeers(accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error + MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *Account) error DeletePeer(accountID, peerID, userID string) error UpdatePeer(accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) GetNetworkMap(peerID string) (*NetworkMap, error) @@ -117,8 +118,8 @@ type AccountManager interface { SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API - SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API + LoginPeer(login PeerLogin) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API + SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -130,6 +131,8 @@ type AccountManager interface { UpdateIntegratedValidatorGroups(accountID string, userID string, groups []string) error GroupValidation(accountId string, groups []string) (bool, error) GetValidatedPeers(account *Account) (map[string]struct{}, error) + SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) + CancelPeerRoutines(peer *nbpeer.Peer) error } type DefaultAccountManager struct { @@ -386,6 +389,8 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group { // GetPeerNetworkMap returns a group by ID if exists, nil otherwise func (a *Account) GetPeerNetworkMap(peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap { + log.Debugf("GetNetworkMap with trace: %s", string(debug.Stack())) + peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ @@ -958,7 +963,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -1009,7 +1014,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string, func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -1108,7 +1113,7 @@ func (am *DefaultAccountManager) warmupIDPCache() error { // DeleteAccount deletes an account and all its users from local store and from the remote IDP if the requester is an admin and account owner func (am *DefaultAccountManager) DeleteAccount(accountID, userID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -1567,7 +1572,7 @@ func (am *DefaultAccountManager) MarkPATUsed(tokenID string) error { return err } - unlock := am.Store.AcquireAccountLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(account.Id) defer unlock() account, err = am.Store.GetAccountByUser(user.Id) @@ -1650,7 +1655,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat if err != nil { return nil, nil, err } - unlock := am.Store.AcquireAccountLock(newAcc.Id) + unlock := am.Store.AcquireAccountWriteLock(newAcc.Id) alreadyUnlocked := false defer func() { if !alreadyUnlocked { @@ -1801,7 +1806,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla account, err := am.Store.GetAccountByUser(claims.UserId) if err == nil { - unlockAccount := am.Store.AcquireAccountLock(account.Id) + unlockAccount := am.Store.AcquireAccountWriteLock(account.Id) defer unlockAccount() account, err = am.Store.GetAccountByUser(claims.UserId) if err != nil { @@ -1821,7 +1826,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla return account, nil } else if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if domainAccount != nil { - unlockAccount := am.Store.AcquireAccountLock(domainAccount.Id) + unlockAccount := am.Store.AcquireAccountWriteLock(domainAccount.Id) defer unlockAccount() domainAccount, err = am.Store.GetAccountByPrivateDomain(claims.Domain) if err != nil { @@ -1835,6 +1840,56 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtcla } } +func (am *DefaultAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *NetworkMap, error) { + accountID, err := am.Store.GetAccountIDByPeerPubKey(peerPubKey) + if err != nil { + return nil, nil, err + } + + unlock := am.Store.AcquireAccountReadLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, nil, err + } + + peer, netMap, err := am.SyncPeer(PeerSync{WireGuardPubKey: peerPubKey}, account) + if err != nil { + return nil, nil, mapError(err) + } + + err = am.MarkPeerConnected(peerPubKey, true, realIP, account) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peerPubKey, err) + } + + return peer, netMap, nil +} + +func (am *DefaultAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { + accountID, err := am.Store.GetAccountIDByPeerPubKey(peer.Key) + if err != nil { + return err + } + + unlock := am.Store.AcquireAccountWriteLock(accountID) + defer unlock() + + account, err := am.Store.GetAccount(accountID) + if err != nil { + return err + } + + err = am.MarkPeerConnected(peer.Key, false, nil, account) + if err != nil { + log.Warnf("failed marking peer as connected %s %v", peer.Key, err) + } + + return nil + +} + // GetAllConnectedPeers returns connected peers based on peersUpdateManager.GetAllConnectedPeers() func (am *DefaultAccountManager) GetAllConnectedPeers() (map[string]struct{}, error) { return am.peersUpdateManager.GetAllConnectedPeers(), nil diff --git a/management/server/account_test.go b/management/server/account_test.go index 456963361..0b2992760 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1655,7 +1655,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + _, err = manager.GetAccountByUserOrAccountID(userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1666,7 +1666,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) + + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") account, err = manager.UpdateAccountSettings(account.Id, userID, &Settings{ PeerLoginExpiration: time.Hour, @@ -1732,8 +1735,10 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. }, } + account, err = manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) + err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1745,7 +1750,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + _, err = manager.GetAccountByUserOrAccountID(userID, "", "") require.NoError(t, err, "unable to create an account") key, err := wgtypes.GenerateKey() @@ -1756,7 +1761,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil) + + account, err := manager.GetAccountByUserOrAccountID(userID, "", "") + require.NoError(t, err, "unable to get the account") + err = manager.MarkPeerConnected(key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} diff --git a/management/server/dns.go b/management/server/dns.go index f6e3531ec..5e2febf55 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -35,7 +35,7 @@ func (d DNSSettings) Copy() DNSSettings { // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) (*DNSSettings, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -57,7 +57,7 @@ func (am *DefaultAccountManager) GetDNSSettings(accountID string, userID string) // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/event.go b/management/server/event.go index dd253717a..303f88a79 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -12,7 +12,7 @@ import ( // GetEvents returns a list of activity events of an account func (am *DefaultAccountManager) GetEvents(accountID, userID string) ([]*activity.Event, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/file_store.go b/management/server/file_store.go index 2de852bee..a6e29ec44 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -279,8 +279,8 @@ func (s *FileStore) AcquireGlobalLock() (unlock func()) { return unlock } -// AcquireAccountLock acquires account lock and returns a function that releases the lock -func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) { +// AcquireAccountWriteLock acquires account lock for writing to a resource and returns a function that releases the lock +func (s *FileStore) AcquireAccountWriteLock(accountID string) (unlock func()) { log.Debugf("acquiring lock for account %s", accountID) start := time.Now() value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) @@ -295,6 +295,12 @@ func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) { return unlock } +// AcquireAccountReadLock AcquireAccountWriteLock acquires account lock for reading a resource and returns a function that releases the lock +// This method is still returns a write lock as file store can't handle read locks +func (s *FileStore) AcquireAccountReadLock(accountID string) (unlock func()) { + return s.AcquireAccountWriteLock(accountID) +} + func (s *FileStore) SaveAccount(account *Account) error { s.mux.Lock() defer s.mux.Unlock() @@ -572,6 +578,18 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return account.Copy(), nil } +func (s *FileStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.PeerKeyID2AccountID[peerKey] + if !ok { + return "", status.Errorf(status.NotFound, "provided peer key doesn't exists %s", peerKey) + } + + return accountID, nil +} + // GetInstallationID returns the installation ID from the store func (s *FileStore) GetInstallationID() string { return s.InstallationID diff --git a/management/server/group.go b/management/server/group.go index 0d93ab5e5..5232c1e7a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -22,7 +22,7 @@ func (e *GroupLinkError) Error() string { // GetGroup object of the peers func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID, userID string) (*n // GetAllGroups returns all groups in an account func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -76,7 +76,7 @@ func (am *DefaultAccountManager) GetAllGroups(accountID string, userID string) ( // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -109,7 +109,7 @@ func (am *DefaultAccountManager) GetGroupByName(groupName, accountID string) (*n // SaveGroup object of the peers func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *nbgroup.Group) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -214,7 +214,7 @@ func difference(a, b []string) []string { // DeleteGroup object of the peers func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error { - unlock := am.Store.AcquireAccountLock(accountId) + unlock := am.Store.AcquireAccountWriteLock(accountId) defer unlock() account, err := am.Store.GetAccount(accountId) @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) // ListGroups objects of the peers func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -341,7 +341,7 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*nbgroup.Group, // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerID string) // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4df24711e..e65046117 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -134,9 +134,9 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi return err } - peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()}) + peer, netMap, err := s.accountManager.SyncAndMarkPeer(peerKey.String(), realIP) if err != nil { - return mapError(err) + return err } err = s.sendInitialSync(peerKey, peer, netMap, srv) @@ -149,11 +149,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi s.ephemeralManager.OnPeerConnected(peer) - err = s.accountManager.MarkPeerConnected(peerKey.String(), true, realIP) - if err != nil { - log.Warnf("failed marking peer as connected %s %v", peerKey, err) - } - if s.config.TURNConfig.TimeBasedCredentials { s.turnCredentialsManager.SetupRefresh(peer.ID) } @@ -207,7 +202,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi func (s *GRPCServer) cancelPeerRoutines(peer *nbpeer.Peer) { s.peersUpdateManager.CloseChannel(peer.ID) s.turnCredentialsManager.CancelRefresh(peer.ID) - _ = s.accountManager.MarkPeerConnected(peer.Key, false, nil) + _ = s.accountManager.CancelPeerRoutines(peer) s.ephemeralManager.OnPeerDisconnected(peer) } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index cd770a801..198f8d527 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -31,7 +31,7 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(accountID strin return errors.New("invalid groups") } - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() a, err := am.Store.GetAccountByUser(userID) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e3f0edd01..259bd645d 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -28,6 +28,7 @@ type MockAccountManager struct { ListUsersFunc func(accountID string) ([]*server.User, error) GetPeersFunc func(accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) DeletePeerFunc func(accountID, peerKey, userID string) error GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) GetPeerNetworkFunc func(peerKey string) (*server.Network, error) @@ -82,7 +83,7 @@ type MockAccountManager struct { GetPeerFunc func(accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) LoginPeerFunc func(login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, error) - SyncPeerFunc func(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) + SyncPeerFunc func(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) InviteUserFunc func(accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -96,6 +97,18 @@ type MockAccountManager struct { GroupValidationFunc func(accountId string, groups []string) (bool, error) } +func (am *MockAccountManager) SyncAndMarkPeer(peerPubKey string, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, error) { + if am.SyncAndMarkPeerFunc != nil { + return am.SyncAndMarkPeerFunc(peerPubKey, realIP) + } + return nil, nil, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") +} + +func (am *MockAccountManager) CancelPeerRoutines(peer *nbpeer.Peer) error { + // TODO implement me + panic("implement me") +} + func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { approvedPeers := make(map[string]struct{}) for id := range account.Peers { @@ -180,7 +193,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountID( } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP) error { +func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool, realIP net.IP, account *server.Account) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(peerKey, connected, realIP) } @@ -626,9 +639,9 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*nbpeer.Peer, * } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*nbpeer.Peer, *server.NetworkMap, error) { +func (am *MockAccountManager) SyncPeer(sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, error) { if am.SyncPeerFunc != nil { - return am.SyncPeerFunc(sync) + return am.SyncPeerFunc(sync, account) } return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index fa7793602..44d231c3e 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -19,7 +19,7 @@ const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -47,7 +47,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, userID, nsGroupID // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -94,7 +94,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if nsGroupToSave == nil { @@ -129,7 +129,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -159,7 +159,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, use // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(accountID string, userID string) ([]*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/peer.go b/management/server/peer.go index 57aa91316..95d0179ce 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -88,21 +88,7 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*nbpeer.P } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP) error { - account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) - if err != nil { - return err - } - - unlock := am.Store.AcquireAccountLock(account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(account.Id) - if err != nil { - return err - } - +func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool, realIP net.IP, account *Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err @@ -156,7 +142,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, and Peer.LoginExpirationEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -278,7 +264,7 @@ func (am *DefaultAccountManager) deletePeers(account *Account, peerIDs []string, // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(accountID, peerID, userID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -362,7 +348,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P return nil, nil, status.Errorf(status.NotFound, "failed adding new peer: account not found") } - unlock := am.Store.AcquireAccountLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(account.Id) defer unlock() // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) @@ -381,7 +367,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. - // Such case is possible when AddPeer function takes long time to finish after AcquireAccountLock (e.g., database is slow) + // Such case is possible when AddPeer function takes long time to finish after AcquireAccountWriteLock (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. @@ -518,25 +504,7 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *nbpeer.P } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*nbpeer.Peer, *NetworkMap, error) { - account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey) - if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") - } - return nil, nil, err - } - - // we found the peer, and we follow a normal login flow - unlock := am.Store.AcquireAccountLock(account.Id) - defer unlock() - - // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies - account, err = am.Store.GetAccount(account.Id) - if err != nil { - return nil, nil, err - } - +func (am *DefaultAccountManager) SyncPeer(sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { return nil, nil, status.Errorf(status.Unauthenticated, "peer is not registered") @@ -603,7 +571,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*nbpeer.Peer, *Netw } // we found the peer, and we follow a normal login flow - unlock := am.Store.AcquireAccountLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(account.Id) defer unlock() // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies @@ -760,7 +728,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) return err } - unlock := am.Store.AcquireAccountLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(account.Id) defer unlock() // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) @@ -795,7 +763,7 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerID string, sshKey string) // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(accountID, peerID, userID string) (*nbpeer.Peer, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/policy.go b/management/server/policy.go index e162d2b3b..704825cae 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -314,7 +314,7 @@ func (a *Account) connResourcesGenerator() (func(*PolicyRule, []*nbpeer.Peer, in // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -342,7 +342,7 @@ func (am *DefaultAccountManager) GetPolicy(accountID, policyID, userID string) ( // SavePolicy in the store func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Policy) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -370,7 +370,7 @@ func (am *DefaultAccountManager) SavePolicy(accountID, userID string, policy *Po // DeletePolicy from the store func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -397,7 +397,7 @@ func (am *DefaultAccountManager) DeletePolicy(accountID, policyID, userID string // ListPolicies from the store func (am *DefaultAccountManager) ListPolicies(accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 7e654b5fb..fb904c10f 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -7,7 +7,7 @@ import ( ) func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -34,7 +34,7 @@ func (am *DefaultAccountManager) GetPostureChecks(accountID, postureChecksID, us } func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -81,7 +81,7 @@ func (am *DefaultAccountManager) SavePostureChecks(accountID, userID string, pos } func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, userID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -113,7 +113,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(accountID, postureChecksID, } func (am *DefaultAccountManager) ListPostureChecks(accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/route.go b/management/server/route.go index 0b7658441..2de813d48 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -14,7 +14,7 @@ import ( // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(accountID string, routeID route.ID, userID string) (*route.Route, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -115,7 +115,7 @@ func (am *DefaultAccountManager) checkRoutePrefixExistsForPeers(account *Account // CreateRoute creates and saves a new route func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, peerGroupIDs []string, description string, netID route.NetID, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -194,7 +194,7 @@ func (am *DefaultAccountManager) CreateRoute(accountID, network, peerID string, // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave *route.Route) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if routeToSave == nil { @@ -255,7 +255,7 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, userID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -283,7 +283,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID string, routeID route.ID, // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index ff6fb3204..40b8ac457 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -209,7 +209,7 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() keyDuration := DefaultSetupKeyDuration @@ -255,7 +255,7 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if keyToSave == nil { @@ -327,7 +327,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -359,7 +359,7 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index 853816fd3..a206575a7 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -127,17 +127,33 @@ func (s *SqliteStore) AcquireGlobalLock() (unlock func()) { return unlock } -func (s *SqliteStore) AcquireAccountLock(accountID string) (unlock func()) { - log.Tracef("acquiring lock for account %s", accountID) +func (s *SqliteStore) AcquireAccountWriteLock(accountID string) (unlock func()) { + log.Tracef("acquiring write lock for account %s", accountID) start := time.Now() - value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) - mtx := value.(*sync.Mutex) + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) mtx.Lock() unlock = func() { mtx.Unlock() - log.Tracef("released lock for account %s in %v", accountID, time.Since(start)) + log.Tracef("released write lock for account %s in %v", accountID, time.Since(start)) + } + + return unlock +} + +func (s *SqliteStore) AcquireAccountReadLock(accountID string) (unlock func()) { + log.Tracef("acquiring read lock for account %s", accountID) + + start := time.Now() + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.RWMutex{}) + mtx := value.(*sync.RWMutex) + mtx.RLock() + + unlock = func() { + mtx.RUnlock() + log.Tracef("released read lock for account %s in %v", accountID, time.Since(start)) } return unlock @@ -263,36 +279,43 @@ func (s *SqliteStore) GetInstallationID() string { } func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error { - var peer nbpeer.Peer + var peerCopy nbpeer.Peer + peerCopy.Status = &peerStatus + result := s.db.Model(&nbpeer.Peer{}). + Where("account_id = ? AND id = ?", accountID, peerID). + Updates(peerCopy) - result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "peer %s not found", peerID) - } - log.Errorf("error when getting peer from the store: %s", result.Error) - return status.Errorf(status.Internal, "issue getting peer from store") + return result.Error } - peer.Status = &peerStatus + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "peer %s not found", peerID) + } - return s.db.Save(peer).Error + return nil } func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error { - var peer nbpeer.Peer - result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID) + // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. + var peerCopy nbpeer.Peer + // Since the location field has been migrated to JSON serialization, + // updating the struct ensures the correct data format is inserted into the database. + peerCopy.Location = peerWithLocation.Location + + result := s.db.Model(&nbpeer.Peer{}). + Where("account_id = ? and id = ?", accountID, peerWithLocation.ID). + Updates(peerCopy) + if result.Error != nil { - if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "peer %s not found", peer.ID) - } - log.Errorf("error when getting peer from the store: %s", result.Error) - return status.Errorf(status.Internal, "issue getting peer from store") + return result.Error } - peer.Location = peerWithLocation.Location + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "peer %s not found", peerWithLocation.ID) + } - return s.db.Save(peer).Error + return nil } // DeleteHashedPAT2TokenIDIndex is noop in Sqlite @@ -400,6 +423,7 @@ func (s *SqliteStore) GetAllAccounts() (all []*Account) { } func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { + var account Account result := s.db.Model(&account). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference @@ -521,6 +545,21 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return s.GetAccount(peer.AccountID) } +func (s *SqliteStore) GetAccountIDByPeerPubKey(peerKey string) (string, error) { + var peer nbpeer.Peer + var accountID string + result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return "", status.Errorf(status.Internal, "issue getting account from store") + } + + return accountID, nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { var user User diff --git a/management/server/store.go b/management/server/store.go index 77b8d0dad..8674f1cf2 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -19,6 +19,7 @@ type Store interface { DeleteAccount(account *Account) error GetAccountByUser(userID string) (*Account, error) GetAccountByPeerPubKey(peerKey string) (*Account, error) + GetAccountIDByPeerPubKey(peerKey string) (string, error) GetAccountByPeerID(peerID string) (*Account, error) GetAccountBySetupKey(setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(domain string) (*Account, error) @@ -29,8 +30,10 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetInstallationID() string SaveInstallationID(ID string) error - // AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock - AcquireAccountLock(accountID string) func() + // AcquireAccountWriteLock should attempt to acquire account lock for write purposes and return a function that releases the lock + AcquireAccountWriteLock(accountID string) func() + // AcquireAccountReadLock should attempt to acquire account lock for read purposes and return a function that releases the lock + AcquireAccountReadLock(accountID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock() func() SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error diff --git a/management/server/user.go b/management/server/user.go index 4ae13d101..6d1879285 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -210,7 +210,7 @@ func NewOwnerUser(id string) *User { // createServiceUser creates a new service user under the given account. func (am *DefaultAccountManager) createServiceUser(accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -266,7 +266,7 @@ func (am *DefaultAccountManager) CreateUser(accountID, userID string, user *User // inviteNewUser Invites a USer to a given account and creates reference in datastore func (am *DefaultAccountManager) inviteNewUser(accountID, userID string, invite *UserInfo) (*UserInfo, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if am.idpManager == nil { @@ -367,7 +367,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireAccountLock(account.Id) + unlock := am.Store.AcquireAccountWriteLock(account.Id) defer unlock() account, err = am.Store.GetAccount(account.Id) @@ -400,7 +400,7 @@ func (am *DefaultAccountManager) GetUser(claims jwtclaims.AuthorizationClaims) ( // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. func (am *DefaultAccountManager) ListUsers(accountID string) ([]*User, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -427,7 +427,7 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t if initiatorUserID == targetUserID { return status.Errorf(status.InvalidArgument, "self deletion is not allowed") } - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -537,7 +537,7 @@ func (am *DefaultAccountManager) deleteUserPeers(initiatorUserID string, targetU // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID string, targetUserID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if am.idpManager == nil { @@ -577,7 +577,7 @@ func (am *DefaultAccountManager) InviteUser(accountID string, initiatorUserID st // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if tokenName == "" { @@ -627,7 +627,7 @@ func (am *DefaultAccountManager) CreatePAT(accountID string, initiatorUserID str // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -677,7 +677,7 @@ func (am *DefaultAccountManager) DeletePAT(accountID string, initiatorUserID str // GetPAT returns a specific PAT from a user func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -709,7 +709,7 @@ func (am *DefaultAccountManager) GetPAT(accountID string, initiatorUserID string // GetAllPATs returns all PATs for a user func (am *DefaultAccountManager) GetAllPATs(accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() account, err := am.Store.GetAccount(accountID) @@ -747,7 +747,7 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. func (am *DefaultAccountManager) SaveOrAddUser(accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { - unlock := am.Store.AcquireAccountLock(accountID) + unlock := am.Store.AcquireAccountWriteLock(accountID) defer unlock() if update == nil {