diff --git a/client/cmd/testutil.go b/client/cmd/testutil.go index 861a69bf8..7bce7f511 100644 --- a/client/cmd/testutil.go +++ b/client/cmd/testutil.go @@ -62,7 +62,7 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste t.Fatal(err) } s := grpc.NewServer() - store, err := mgmt.NewStore(config.Datadir) + store, err := mgmt.NewFileStore(config.Datadir) if err != nil { t.Fatal(err) } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index e2b542ced..56d9eb66f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -935,7 +935,7 @@ func startManagement(port int, dataDir string) (*grpc.Server, error) { return nil, err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, err := server.NewStore(config.Datadir) + store, err := server.NewFileStore(config.Datadir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/client/client_test.go b/management/client/client_test.go index c48f151b5..129dbf1ca 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -49,7 +49,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, err := mgmt.NewStore(config.Datadir) + store, err := mgmt.NewFileStore(config.Datadir) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index bcd8252a8..82168ec72 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -116,7 +116,7 @@ var ( } } - store, err := server.NewStore(config.Datadir) + store, err := server.NewFileStore(config.Datadir) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } @@ -250,6 +250,7 @@ var ( _ = certManager.Listener().Close() } gRPCAPIHandler.Stop() + _ = store.Close() log.Infof("stopped Management Service") return nil diff --git a/management/server/account_test.go b/management/server/account_test.go index 786f750ef..b413d7b22 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1211,7 +1211,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) { func createStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewStore(dataDir) + store, err := NewFileStore(dataDir) if err != nil { return nil, err } diff --git a/management/server/file_store.go b/management/server/file_store.go index 9cf3750e3..8206713ed 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -37,8 +37,8 @@ type FileStore struct { type StoredAccount struct{} -// NewStore restores a store from the file located in the datadir -func NewStore(dataDir string) (*FileStore, error) { +// NewFileStore restores a store from the file located in the datadir +func NewFileStore(dataDir string) (*FileStore, error) { return restore(filepath.Join(dataDir, storeFileName)) } @@ -198,7 +198,12 @@ func (s *FileStore) GetAccountByPrivateDomain(domain string) (*Account, error) { ) } - return s.getAccount(accountID) + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Copy(), nil } // GetAccountBySetupKey returns account by setup key id @@ -211,7 +216,12 @@ func (s *FileStore) GetAccountBySetupKey(setupKey string) (*Account, error) { return nil, status.Errorf(codes.NotFound, "account not found: provided setup key doesn't exists") } - return s.getAccount(accountID) + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Copy(), nil } // GetAllAccounts returns all accounts @@ -225,13 +235,14 @@ func (s *FileStore) GetAllAccounts() (all []*Account) { return all } +// getAccount returns a reference to the Account. Should not return a copy. func (s *FileStore) getAccount(accountID string) (*Account, error) { account, accountFound := s.Accounts[accountID] if !accountFound { return nil, status.Errorf(codes.NotFound, "account not found") } - return account.Copy(), nil + return account, nil } // GetAccount returns an account for ID @@ -239,7 +250,12 @@ func (s *FileStore) GetAccount(accountID string) (*Account, error) { s.mux.Lock() defer s.mux.Unlock() - return s.getAccount(accountID) + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Copy(), nil } // GetAccountByUser returns a user account @@ -252,7 +268,12 @@ func (s *FileStore) GetAccountByUser(userID string) (*Account, error) { return nil, status.Errorf(codes.NotFound, "account not found") } - return s.getAccount(accountID) + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Copy(), nil } // GetAccountByPeerPubKey returns an account for a given peer WireGuard public key @@ -265,7 +286,12 @@ func (s *FileStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { return nil, status.Errorf(codes.NotFound, "Provided peer key doesn't exists %s", peerKey) } - return s.getAccount(accountID) + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Copy(), nil } // GetInstallationID returns the installation ID from the store @@ -274,11 +300,42 @@ func (s *FileStore) GetInstallationID() string { } // SaveInstallationID saves the installation ID -func (s *FileStore) SaveInstallationID(id string) error { +func (s *FileStore) SaveInstallationID(ID string) error { s.mux.Lock() defer s.mux.Unlock() - s.InstallationID = id + s.InstallationID = ID + + return s.persist(s.storeFile) +} + +// SavePeerStatus stores the PeerStatus in memory. It doesn't attempt to persist data to speed up things. +// PeerStatus will be saved eventually when some other changes occur. +func (s *FileStore) SavePeerStatus(accountID, peerKey string, peerStatus PeerStatus) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + peer := account.Peers[peerKey] + if peer == nil { + return status.Errorf(codes.NotFound, "peer %s not found", peerKey) + } + + peer.Status = &peerStatus + + return nil +} + +// Close the FileStore persisting data to disk +func (s *FileStore) Close() error { + s.mux.Lock() + defer s.mux.Unlock() + + log.Infof("closing FileStore") return s.persist(s.storeFile) } diff --git a/management/server/file_store_test.go b/management/server/file_store_test.go index 11439af10..69574bdec 100644 --- a/management/server/file_store_test.go +++ b/management/server/file_store_test.go @@ -93,7 +93,7 @@ func TestStore(t *testing.T) { return } - restored, err := NewStore(store.storeFile) + restored, err := NewFileStore(store.storeFile) if err != nil { return } @@ -129,7 +129,7 @@ func TestRestore(t *testing.T) { t.Fatal(err) } - store, err := NewStore(storeDir) + store, err := NewFileStore(storeDir) if err != nil { return } @@ -161,7 +161,7 @@ func TestGetAccountByPrivateDomain(t *testing.T) { t.Fatal(err) } - store, err := NewStore(storeDir) + store, err := NewFileStore(storeDir) if err != nil { return } @@ -190,7 +190,7 @@ func TestFileStore_GetAccount(t *testing.T) { t.Fatal(err) } - store, err := NewStore(storeDir) + store, err := NewFileStore(storeDir) if err != nil { t.Fatal(err) } @@ -218,8 +218,59 @@ func TestFileStore_GetAccount(t *testing.T) { assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups)) } +func TestFileStore_SavePeerStatus(t *testing.T) { + storeDir := t.TempDir() + + err := util.CopyFileContents("testdata/store.json", filepath.Join(storeDir, "store.json")) + if err != nil { + t.Fatal(err) + } + + store, err := NewFileStore(storeDir) + if err != nil { + return + } + + account, err := store.getAccount("bf1c8084-ba50-4ce7-9439-34653001fc3b") + if err != nil { + t.Fatal(err) + } + + // save status of non-existing peer + newStatus := PeerStatus{Connected: true, LastSeen: time.Now()} + err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) + assert.Error(t, err) + + // save new status of existing peer + account.Peers["testpeer"] = &Peer{ + Key: "peerkey", + SetupKey: "peerkeysetupkey", + IP: net.IP{127, 0, 0, 1}, + Meta: PeerSystemMeta{}, + Name: "peer name", + Status: &PeerStatus{Connected: false, LastSeen: time.Now()}, + } + + err = store.SaveAccount(account) + if err != nil { + t.Fatal(err) + } + + err = store.SavePeerStatus(account.Id, "testpeer", newStatus) + if err != nil { + t.Fatal(err) + } + account, err = store.getAccount(account.Id) + if err != nil { + t.Fatal(err) + } + + actual := account.Peers["testpeer"].Status + assert.Equal(t, newStatus, *actual) +} + func newStore(t *testing.T) *FileStore { - store, err := NewStore(t.TempDir()) + store, err := NewFileStore(t.TempDir()) if err != nil { t.Errorf("failed creating a new store") } diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 9ccabf6c4..5dc20b0cb 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -398,7 +398,7 @@ func startManagement(t *testing.T, port int, config *Config) (*grpc.Server, erro return nil, err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, err := NewStore(config.Datadir) + store, err := NewFileStore(config.Datadir) if err != nil { return nil, err } diff --git a/management/server/management_test.go b/management/server/management_test.go index 2e0d003c4..f4c1b2ac0 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -488,7 +488,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) { Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, err := server.NewStore(config.Datadir) + store, err := server.NewFileStore(config.Datadir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index ea837a0d3..30cc8246f 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -1061,7 +1061,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { func createNSStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewStore(dataDir) + store, err := NewFileStore(dataDir) if err != nil { return nil, err } diff --git a/management/server/peer.go b/management/server/peer.go index 96a684345..08611ae8f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -74,6 +74,14 @@ func (p *Peer) Copy() *Peer { } } +// Copy PeerStatus +func (p *PeerStatus) Copy() *PeerStatus { + return &PeerStatus{ + LastSeen: p.LastSeen, + Connected: p.Connected, + } +} + // GetPeer looks up peer by its public WireGuard key func (am *DefaultAccountManager) GetPeer(peerPubKey string) (*Peer, error) { @@ -133,12 +141,13 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerPubKey string, connected return err } - peer.Status.LastSeen = time.Now() - peer.Status.Connected = connected - + newStatus := peer.Status.Copy() + newStatus.LastSeen = time.Now() + newStatus.Connected = connected + peer.Status = newStatus account.UpdatePeer(peer) - err = am.Store.SaveAccount(account) + err = am.Store.SavePeerStatus(account.Id, peerPubKey, *newStatus) if err != nil { return err } diff --git a/management/server/route_test.go b/management/server/route_test.go index 1d47c7e01..8acb95b2e 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -788,7 +788,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { func createRouterStore(t *testing.T) (Store, error) { dataDir := t.TempDir() - store, err := NewStore(dataDir) + store, err := NewFileStore(dataDir) if err != nil { return nil, err } diff --git a/management/server/store.go b/management/server/store.go index fcfd58bca..3954f0e8c 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -9,9 +9,12 @@ type Store interface { GetAccountByPrivateDomain(domain string) (*Account, error) SaveAccount(account *Account) error GetInstallationID() string - SaveInstallationID(id string) error + SaveInstallationID(ID string) error // AcquireAccountLock should attempt to acquire account lock and return a function that releases the lock AcquireAccountLock(accountID string) func() // AcquireGlobalLock should attempt to acquire a global lock and return a function that releases the lock AcquireGlobalLock() func() + SavePeerStatus(accountID, peerKey string, status PeerStatus) error + // Close should close the store persisting all unsaved data. + Close() error }