From ed7ac81027fdd4680e2ea3e6aeadab3ab17cc95b Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Mon, 7 Nov 2022 17:52:23 +0100 Subject: [PATCH] Introduce locking on the account level (#548) --- management/server/account.go | 61 ++++++------------ management/server/account_test.go | 22 +++---- management/server/file_store.go | 38 +++++++++++- management/server/group.go | 40 +++++++----- management/server/mock_server/account_mock.go | 13 +--- management/server/nameserver.go | 30 +++++---- management/server/nameserver_test.go | 2 +- management/server/peer.go | 62 ++++++++++++------- management/server/route.go | 24 +++---- management/server/route_test.go | 4 +- management/server/rule.go | 20 +++--- management/server/setupkey.go | 16 ++--- management/server/store.go | 4 ++ management/server/user.go | 30 ++++----- 14 files changed, 200 insertions(+), 166 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 9ceff43a9..2cb0be3ed 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -38,7 +38,6 @@ func cacheEntryExpiration() time.Duration { type AccountManager interface { GetOrCreateAccountByUser(userId, domain string) (*Account, error) - GetAccountByUser(userId string) (*Account, error) CreateSetupKey( accountId string, keyName string, @@ -51,8 +50,7 @@ type AccountManager interface { ListSetupKeys(accountID, userID string) ([]*SetupKey, error) SaveUser(accountID string, key *User) (*UserInfo, error) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) - GetAccountById(accountId string) (*Account, error) - GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) + GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) @@ -97,8 +95,6 @@ type AccountManager interface { type DefaultAccountManager struct { Store Store - // mux to synchronise account operations (e.g. generating Peer IP address inside the Network) - mux sync.Mutex // cacheMux and cacheLoading helps to make sure that only a single cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded @@ -359,7 +355,6 @@ func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManage singleAccountModeDomain string, dnsDomain string) (*DefaultAccountManager, error) { am := &DefaultAccountManager{ Store: store, - mux: sync.Mutex{}, peersUpdateManager: peersUpdateManager, idpManager: idpManager, ctx: context.Background(), @@ -460,32 +455,17 @@ func (am *DefaultAccountManager) warmupIDPCache() error { return nil } -// GetAccountById returns an existing account using its ID or error (NotFound) if doesn't exist -func (am *DefaultAccountManager) GetAccountById(accountId string) (*Account, error) { - am.mux.Lock() - defer am.mux.Unlock() - - account, err := am.Store.GetAccount(accountId) - if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found") - } - - return account, nil -} - -// GetAccountByUserOrAccountId look for an account by user or account Id, if no account is provided and -// user id doesn't have an account associated with it, one account is created -func (am *DefaultAccountManager) GetAccountByUserOrAccountId( - userId, accountId, domain string, -) (*Account, error) { - if accountId != "" { - return am.GetAccountById(accountId) - } else if userId != "" { - account, err := am.GetOrCreateAccountByUser(userId, domain) +// GetAccountByUserOrAccountID looks for an account by user or accountID, if no account is provided and +// userID doesn't have an account associated with it, one account is created +func (am *DefaultAccountManager) GetAccountByUserOrAccountID(userID, accountID, domain string) (*Account, error) { + if accountID != "" { + return am.Store.GetAccount(accountID) + } else if userID != "" { + account, err := am.GetOrCreateAccountByUser(userID, domain) if err != nil { - return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userId) + return nil, status.Errorf(codes.NotFound, "account not found using user id: %s", userID) } - err = am.addAccountIDToIDPAppMeta(userId, account) + err = am.addAccountIDToIDPAppMeta(userID, account) if err != nil { return nil, err } @@ -825,15 +805,13 @@ func (am *DefaultAccountManager) GetAccountFromToken(claims jwtclaims.Authorizat // Existing user + Existing account + Existing Indexed Domain -> Nothing changes // // Existing user + Existing account + Existing domain reclassified Domain as private -> Nothing changes (index domain) -func (am *DefaultAccountManager) getAccountWithAuthorizationClaims( - claims jwtclaims.AuthorizationClaims, -) (*Account, error) { +func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(claims jwtclaims.AuthorizationClaims) (*Account, error) { // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { - return am.GetAccountByUserOrAccountId(claims.UserId, claims.AccountId, claims.Domain) + return am.GetAccountByUserOrAccountID(claims.UserId, claims.AccountId, claims.Domain) } else if claims.AccountId != "" { - accountFromID, err := am.GetAccountById(claims.AccountId) + accountFromID, err := am.Store.GetAccount(claims.AccountId) if err != nil { return nil, err } @@ -845,8 +823,8 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims( } } - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireGlobalLock() + defer unlock() // We checked if the domain has a primary account already domainAccount, err := am.Store.GetAccountByPrivateDomain(claims.Domain) @@ -876,12 +854,13 @@ func isDomainValid(domain string) bool { } // AccountExists checks whether account exists (returns true) or not (returns false) -func (am *DefaultAccountManager) AccountExists(accountId string) (*bool, error) { - am.mux.Lock() - defer am.mux.Unlock() +func (am *DefaultAccountManager) AccountExists(accountID string) (*bool, error) { + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() var res bool - _, err := am.Store.GetAccount(accountId) + _, err := am.Store.GetAccount(accountID) if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { res = false diff --git a/management/server/account_test.go b/management/server/account_test.go index ce809f345..786f750ef 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -121,7 +121,7 @@ func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.GetAccountByUser(userId) + account, err = manager.Store.GetAccountByUser(userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -302,7 +302,7 @@ func TestDefaultAccountManager_GetAccountFromToken(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") - initAccount, err := manager.GetAccountByUserOrAccountId(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) + initAccount, err := manager.GetAccountByUserOrAccountID(testCase.inputInitUserParams.UserId, testCase.inputInitUserParams.AccountId, testCase.inputInitUserParams.Domain) require.NoError(t, err, "create init user failed") if testCase.inputUpdateAttrs { @@ -345,7 +345,7 @@ func TestAccountManager_PrivateAccount(t *testing.T) { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.GetAccountByUser(userId) + account, err = manager.Store.GetAccountByUser(userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -401,7 +401,7 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { userId := "test_user" - account, err := manager.GetAccountByUserOrAccountId(userId, "", "") + account, err := manager.GetAccountByUserOrAccountID(userId, "", "") if err != nil { t.Fatal(err) } @@ -411,12 +411,12 @@ func TestAccountManager_GetAccountByUserOrAccountId(t *testing.T) { accountId := account.Id - _, err = manager.GetAccountByUserOrAccountId("", accountId, "") + _, err = manager.GetAccountByUserOrAccountID("", accountId, "") if err != nil { t.Errorf("expected to get existing account after creation using userid, no account was found for a account %s", accountId) } - _, err = manager.GetAccountByUserOrAccountId("", "", "") + _, err = manager.GetAccountByUserOrAccountID("", "", "") if err == nil { t.Errorf("expected an error when user and account IDs are empty") } @@ -470,7 +470,7 @@ func TestAccountManager_GetAccount(t *testing.T) { } // AddAccount has been already tested so we can assume it is correct and compare results - getAccount, err := manager.GetAccountById(expectedId) + getAccount, err := manager.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) return @@ -540,7 +540,7 @@ func TestAccountManager_AddPeer(t *testing.T) { return } - account, err = manager.GetAccountById(account.Id) + account, err = manager.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) return @@ -602,7 +602,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err = manager.GetAccountById(account.Id) + account, err = manager.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) return @@ -680,7 +680,7 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { peer2 := getPeer() peer3 := getPeer() - account, err = manager.GetAccountById(account.Id) + account, err = manager.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) return @@ -848,7 +848,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { return } - account, err = manager.GetAccountById(account.Id) + account, err = manager.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) return diff --git a/management/server/file_store.go b/management/server/file_store.go index 87611952a..9451cd4b6 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -1,10 +1,12 @@ package server import ( + log "github.com/sirupsen/logrus" "os" "path/filepath" "strings" "sync" + "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -27,6 +29,10 @@ type FileStore struct { // mutex to synchronise Store read/write operations mux sync.Mutex `json:"-"` storeFile string `json:"-"` + + // sync.Mutex indexed by accountID + accountLocks sync.Map `json:"-"` + globalAccountLock sync.Mutex `json:"-"` } type StoredAccount struct{} @@ -44,6 +50,7 @@ func restore(file string) (*FileStore, error) { s := &FileStore{ Accounts: make(map[string]*Account), mux: sync.Mutex{}, + globalAccountLock: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), UserID2AccountID: make(map[string]string), @@ -111,7 +118,36 @@ func (s *FileStore) persist(file string) error { return util.WriteJson(file, s) } -// SaveAccount updates an existing account or adds a new one +// AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock +func (s *FileStore) AcquireGlobalLock() (unlock func()) { + log.Debugf("acquiring global lock") + start := time.Now() + s.globalAccountLock.Lock() + + unlock = func() { + s.globalAccountLock.Unlock() + log.Debugf("released global lock in %v", time.Since(start)) + } + + return unlock +} + +// AcquireAccountLock acquires account lock and returns a function that releases the lock +func (s *FileStore) AcquireAccountLock(accountID string) (unlock func()) { + log.Debugf("acquiring lock for account %s", accountID) + start := time.Now() + value, _ := s.accountLocks.LoadOrStore(accountID, &sync.Mutex{}) + mtx := value.(*sync.Mutex) + mtx.Lock() + + unlock = func() { + mtx.Unlock() + log.Debugf("released lock for account %s in %v", accountID, time.Since(start)) + } + + return unlock +} + func (s *FileStore) SaveAccount(account *Account) error { s.mux.Lock() defer s.mux.Unlock() diff --git a/management/server/group.go b/management/server/group.go index 00ae3604c..cbe606463 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -47,8 +47,9 @@ func (g *Group) Copy() *Group { // GetGroup object of the peers func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -65,8 +66,9 @@ func (am *DefaultAccountManager) GetGroup(accountID, groupID string) (*Group, er // SaveGroup object of the peers func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -86,8 +88,9 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error // UpdateGroup updates a group using a list of operations func (am *DefaultAccountManager) UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -135,8 +138,9 @@ func (am *DefaultAccountManager) UpdateGroup(accountID string, // DeleteGroup object of the peers func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -155,8 +159,9 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { // ListGroups objects of the peers func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -173,8 +178,9 @@ func (am *DefaultAccountManager) ListGroups(accountID string) ([]*Group, error) // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string) error { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -207,8 +213,9 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey string) error { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -235,8 +242,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str // GroupListPeers returns list of the peers from the group func (am *DefaultAccountManager) GroupListPeers(accountID, groupID string) ([]*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index c694f32ac..4ed231f64 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -15,7 +15,6 @@ type MockAccountManager struct { GetAccountByUserFunc func(userId string) (*server.Account, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) - GetAccountByIdFunc func(accountId string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExistsFunc func(accountId string) (*bool, error) @@ -114,16 +113,8 @@ func (am *MockAccountManager) CreateSetupKey( return nil, status.Errorf(codes.Unimplemented, "method CreateSetupKey is not implemented") } -// GetAccountById mock implementation of GetAccountById from server.AccountManager interface -func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account, error) { - if am.GetAccountByIdFunc != nil { - return am.GetAccountByIdFunc(accountId) - } - return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented") -} - // GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface -func (am *MockAccountManager) GetAccountByUserOrAccountId( +func (am *MockAccountManager) GetAccountByUserOrAccountID( userId, accountId, domain string, ) (*server.Account, error) { if am.GetAccountByUserOrAccountIdFunc != nil { @@ -131,7 +122,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId( } return nil, status.Errorf( codes.Unimplemented, - "method GetAccountByUserOrAccountId is not implemented", + "method GetAccountByUserOrAccountID is not implemented", ) } diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 693e0af4d..b3dbf333a 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -60,8 +60,9 @@ type NameServerGroupUpdateOperation struct { // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -78,8 +79,9 @@ func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -125,8 +127,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(accountID string, name, d // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() if nsGroupToSave == nil { return status.Errorf(codes.InvalidArgument, "nameserver group provided is nil") @@ -161,8 +164,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID string, nsGroupTo // UpdateNameServerGroup updates existing nameserver group with set of operations func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -263,8 +267,9 @@ func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID stri // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID string) error { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -290,8 +295,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID stri // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index d672b4b2a..ea837a0d3 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -635,7 +635,7 @@ func TestSaveNameServerGroup(t *testing.T) { return } - account, err = am.GetAccountById(account.Id) + account, err = am.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) } diff --git a/management/server/peer.go b/management/server/peer.go index ceda49ac4..96a684345 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -76,8 +76,6 @@ func (p *Peer) Copy() *Peer { // GetPeer looks up peer by its public WireGuard key func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { @@ -90,8 +88,7 @@ func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) { // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // the current user is not an admin. func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() + account, err := am.Store.GetAccount(accountID) if err != nil { return nil, err @@ -116,14 +113,21 @@ func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, er // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected bool) error { - am.mux.Lock() - defer am.mux.Unlock() account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { return err } + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + + // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) + account, err = am.Store.GetAccount(account.Id) + if err != nil { + return err + } + peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err @@ -143,8 +147,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected // UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated. func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -188,8 +193,9 @@ func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Pe // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) (*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -237,8 +243,9 @@ func (am *DefaultAccountManager) DeletePeer(accountID string, peerPubKey string) // GetPeerByIP returns peer by its IP func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() + + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -256,8 +263,6 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (* // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, error) { - am.mux.Lock() - defer am.mux.Unlock() account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { @@ -292,8 +297,6 @@ func (am *DefaultAccountManager) GetNetworkMap(peerPubKey string) (*NetworkMap, // GetPeerNetwork returns the Network for a given peer func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, error) { - am.mux.Lock() - defer am.mux.Unlock() account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { @@ -311,8 +314,6 @@ func (am *DefaultAccountManager) GetPeerNetwork(peerPubKey string) (*Network, er // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *Peer) (*Peer, error) { - am.mux.Lock() - defer am.mux.Unlock() upperKey := strings.ToUpper(setupKey) @@ -367,6 +368,15 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P return nil, status.Errorf(codes.InvalidArgument, "no setup key or user id provided") } + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + + // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) + account, err = am.Store.GetAccount(account.Id) + if err != nil { + return nil, err + } + var takenIps []net.IP existingLabels := make(lookupMap) for _, existingPeer := range account.Peers { @@ -433,8 +443,6 @@ func (am *DefaultAccountManager) AddPeer(setupKey string, userID string, peer *P // UpdatePeerSSHKey updates peer's public SSH key func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey string) error { - am.mux.Lock() - defer am.mux.Unlock() if sshKey == "" { log.Debugf("empty SSH key provided for peer %s, skipping update", peerPubKey) @@ -446,6 +454,15 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey stri return err } + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + + // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) + account, err = am.Store.GetAccount(account.Id) + if err != nil { + return err + } + peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err @@ -470,14 +487,15 @@ func (am *DefaultAccountManager) UpdatePeerSSHKey(peerPubKey string, sshKey stri // UpdatePeerMeta updates peer's system metadata func (am *DefaultAccountManager) UpdatePeerMeta(peerPubKey string, meta PeerSystemMeta) error { - am.mux.Lock() - defer am.mux.Unlock() account, err := am.Store.GetAccountByPeerPubKey(peerPubKey) if err != nil { return err } + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return err diff --git a/management/server/route.go b/management/server/route.go index ccc54f750..9f615f9e0 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -61,8 +61,8 @@ type RouteUpdateOperation struct { // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -116,8 +116,8 @@ func (am *DefaultAccountManager) checkPrefixPeerExists(accountID, peer string, p // CreateRoute creates and saves a new route func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -180,8 +180,8 @@ func (am *DefaultAccountManager) CreateRoute(accountID string, network, peer, de // SaveRoute saves route func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route.Route) error { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() if routeToSave == nil { return status.Errorf(codes.InvalidArgument, "route provided is nil") @@ -223,8 +223,8 @@ func (am *DefaultAccountManager) SaveRoute(accountID string, routeToSave *route. // UpdateRoute updates existing route with set of operations func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -320,8 +320,8 @@ func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operatio // DeleteRoute deletes route with routeID func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -340,8 +340,8 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error { // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { diff --git a/management/server/route_test.go b/management/server/route_test.go index 63c268fd8..1d47c7e01 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -380,7 +380,7 @@ func TestSaveRoute(t *testing.T) { return } - account, err = am.GetAccountById(account.Id) + account, err = am.Store.GetAccount(account.Id) if err != nil { t.Fatal(err) } @@ -845,5 +845,5 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - return am.GetAccountById(accountID) + return am.Store.GetAccount(account.Id) } diff --git a/management/server/rule.go b/management/server/rule.go index 8133d4f4a..dd0cf5fa9 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -90,8 +90,8 @@ func (r *Rule) Copy() *Rule { // GetRule of ACL from the store func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -117,8 +117,8 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rul // SaveRule of ACL in the store func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -138,8 +138,8 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { // UpdateRule updates a rule using a list of operations func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -212,8 +212,8 @@ func (am *DefaultAccountManager) UpdateRule(accountID string, ruleID string, // DeleteRule of ACL from the store func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { @@ -232,8 +232,8 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { // ListRules of ACL from the store func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 1a30d9ac9..c91af8c79 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -173,8 +173,8 @@ func Hash(s string) uint32 { // and adds it to the specified account. A list of autoGroups IDs can be empty. func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string) (*SetupKey, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() keyDuration := DefaultSetupKeyDuration if expiresIn != 0 { @@ -208,8 +208,8 @@ func (am *DefaultAccountManager) CreateSetupKey(accountID string, keyName string // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *SetupKey) (*SetupKey, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() if keyToSave == nil { return nil, status.Errorf(codes.InvalidArgument, "provided setup key to update is nil") @@ -249,8 +249,8 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup // ListSetupKeys returns a list of all setup keys of the account func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { return nil, status.Errorf(codes.NotFound, "account not found") @@ -277,8 +277,8 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*Set // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() account, err := am.Store.GetAccount(accountID) if err != nil { diff --git a/management/server/store.go b/management/server/store.go index 2ee655b33..fcfd58bca 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -10,4 +10,8 @@ type Store interface { SaveAccount(account *Account) error GetInstallationID() string SaveInstallationID(id string) error + // AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock + AcquireAccountLock(accountID string) func() + // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock + AcquireGlobalLock() func() } diff --git a/management/server/user.go b/management/server/user.go index 309c6a120..a0e4a6870 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -119,8 +119,8 @@ func NewAdminUser(id string) *User { // CreateUser creates a new user under the given account. Effectively this is a user invite. func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) (*UserInfo, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() if am.idpManager == nil { return nil, Errorf(PreconditionFailed, "IdP manager must be enabled to send user invites") @@ -184,8 +184,8 @@ func (am *DefaultAccountManager) CreateUser(accountID string, invite *UserInfo) // SaveUser saves updates a given user. If the user doesn't exit it will throw status.NotFound error. // Only User.AutoGroups field is allowed to be updated for now. func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*UserInfo, error) { - am.mux.Lock() - defer am.mux.Unlock() + unlock := am.Store.AcquireAccountLock(accountID) + defer unlock() if update == nil { return nil, status.Errorf(codes.InvalidArgument, "provided user update is nil") @@ -234,16 +234,16 @@ func (am *DefaultAccountManager) SaveUser(accountID string, update *User) (*User } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) (*Account, error) { - am.mux.Lock() - defer am.mux.Unlock() +func (am *DefaultAccountManager) GetOrCreateAccountByUser(userID, domain string) (*Account, error) { + unlock := am.Store.AcquireGlobalLock() + defer unlock() lowerDomain := strings.ToLower(domain) - account, err := am.Store.GetAccountByUser(userId) + account, err := am.Store.GetAccountByUser(userID) if err != nil { if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound { - account, err = am.newAccount(userId, lowerDomain) + account, err = am.newAccount(userID, lowerDomain) if err != nil { return nil, err } @@ -257,7 +257,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) } } - userObj := account.Users[userId] + userObj := account.Users[userID] if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin { account.Domain = lowerDomain @@ -270,14 +270,6 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(userId, domain string) return account, nil } -// GetAccountByUser returns an existing account for a given user id, NotFound if account couldn't be found -func (am *DefaultAccountManager) GetAccountByUser(userId string) (*Account, error) { - am.mux.Lock() - defer am.mux.Unlock() - - return am.Store.GetAccountByUser(userId) -} - // IsUserAdmin flag for current user authenticated by JWT token func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) { account, err := am.GetAccountFromToken(claims) @@ -296,7 +288,7 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { - account, err := am.GetAccountById(accountID) + account, err := am.Store.GetAccount(accountID) if err != nil { return nil, err }