From 5dde044fa5c91f6782505fee17fc80007819ec17 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sun, 10 Mar 2024 19:09:45 +0100 Subject: [PATCH] Check for record not found when searching the store (#1686) This change returns status.NotFound only on gorm.ErrRecordNotFound and status.Internal on every other DB error --- management/server/sqlite_store.go | 68 +++++++++++++++++++++----- management/server/sqlite_store_test.go | 47 ++++++++++++++++++ 2 files changed, 103 insertions(+), 12 deletions(-) diff --git a/management/server/sqlite_store.go b/management/server/sqlite_store.go index b077acf48..eff43a31b 100644 --- a/management/server/sqlite_store.go +++ b/management/server/sqlite_store.go @@ -1,6 +1,7 @@ package server import ( + "errors" "fmt" "path/filepath" "runtime" @@ -255,7 +256,11 @@ func (s *SqliteStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerID) if result.Error != nil { - return status.Errorf(status.NotFound, "peer %s not found", peerID) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "peer %s not found", peerID) + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return status.Errorf(status.Internal, "issue getting peer from store") } peer.Status = &peerStatus @@ -267,7 +272,11 @@ func (s *SqliteStore) SavePeerLocation(accountID string, peerWithLocation *nbpee var peer nbpeer.Peer result := s.db.First(&peer, "account_id = ? and id = ?", accountID, peerWithLocation.ID) if result.Error != nil { - return status.Errorf(status.NotFound, "peer %s not found", peer.ID) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "peer %s not found", peer.ID) + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return status.Errorf(status.Internal, "issue getting peer from store") } peer.Location = peerWithLocation.Location @@ -291,7 +300,11 @@ func (s *SqliteStore) GetAccountByPrivateDomain(domain string) (*Account, error) result := s.db.First(&account, "domain = ? and is_domain_primary_account = ? and domain_category = ?", strings.ToLower(domain), true, PrivateCategory) if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private") + } + log.Errorf("error when getting account from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting account from store") } // TODO: rework to not call GetAccount @@ -302,7 +315,11 @@ func (s *SqliteStore) GetAccountBySetupKey(setupKey string) (*Account, error) { var key SetupKey result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting setup key from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting setup key from store") } if key.AccountID == "" { @@ -316,7 +333,11 @@ func (s *SqliteStore) GetTokenIDByHashedToken(hashedToken string) (string, error var token PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { - return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting token from the store: %s", result.Error) + return "", status.Errorf(status.Internal, "issue getting account from store") } return token.ID, nil @@ -326,7 +347,11 @@ func (s *SqliteStore) GetUserByTokenID(tokenID string) (*User, error) { var token PersonalAccessToken result := s.db.First(&token, "id = ?", tokenID) if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting token from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting account from store") } if token.UserID == "" { @@ -370,8 +395,11 @@ func (s *SqliteStore) GetAccount(accountID string) (*Account, error) { Preload(clause.Associations). First(&account, "id = ?", accountID) if result.Error != nil { - log.Errorf("when getting account from the store: %s", result.Error) - return nil, status.Errorf(status.NotFound, "account not found") + log.Errorf("error when getting account from the store: %s", result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found") + } + return nil, status.Errorf(status.Internal, "issue getting account from store") } // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us @@ -431,7 +459,11 @@ func (s *SqliteStore) GetAccountByUser(userID string) (*Account, error) { var user User result := s.db.Select("account_id").First(&user, "id = ?", userID) if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting user from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting account from store") } if user.AccountID == "" { @@ -445,7 +477,11 @@ func (s *SqliteStore) GetAccountByPeerID(peerID string) (*Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, "id = ?", peerID) if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting account from store") } if peer.AccountID == "" { @@ -460,7 +496,11 @@ func (s *SqliteStore) GetAccountByPeerPubKey(peerKey string) (*Account, error) { result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) if result.Error != nil { - return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") + } + log.Errorf("error when getting peer from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting account from store") } if peer.AccountID == "" { @@ -476,7 +516,11 @@ func (s *SqliteStore) SaveUserLastLogin(accountID, userID string, lastLogin time result := s.db.First(&user, "account_id = ? and id = ?", accountID, userID) if result.Error != nil { - return status.Errorf(status.NotFound, "user %s not found", userID) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "user %s not found", userID) + } + log.Errorf("error when getting user from the store: %s", result.Error) + return status.Errorf(status.Internal, "issue getting user from store") } user.LastLogin = lastLogin diff --git a/management/server/sqlite_store_test.go b/management/server/sqlite_store_test.go index 29b49d7f3..e43a0cd9a 100644 --- a/management/server/sqlite_store_test.go +++ b/management/server/sqlite_store_test.go @@ -12,6 +12,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/server/status" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/util" ) @@ -174,6 +176,26 @@ func TestSqlite_DeleteAccount(t *testing.T) { } +func TestSqlite_GetAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/store.json") + + id := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + account, err := store.GetAccount(id) + require.NoError(t, err) + require.Equal(t, id, account.Id, "account id should match") + + _, err = store.GetAccount("non-existing-account") + assert.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") +} + func TestSqlite_SavePeerStatus(t *testing.T) { if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") @@ -188,6 +210,9 @@ func TestSqlite_SavePeerStatus(t *testing.T) { newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus) assert.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") // save new status of existing peer account.Peers["testpeer"] = &nbpeer.Peer{ @@ -254,6 +279,13 @@ func TestSqlite_SavePeerLocation(t *testing.T) { actual := account.Peers[peer.ID].Location assert.Equal(t, peer.Location, actual) + + peer.ID = "non-existing-peer" + err = store.SavePeerLocation(account.Id, peer) + assert.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { @@ -271,6 +303,9 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { _, err = store.GetAccountByPrivateDomain("missing-domain.com") require.Error(t, err, "should return error on domain lookup") + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { @@ -286,6 +321,12 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { token, err := store.GetTokenIDByHashedToken(hashed) require.NoError(t, err) require.Equal(t, id, token) + + _, err = store.GetTokenIDByHashedToken("non-existing-hash") + require.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } func TestSqlite_GetUserByTokenID(t *testing.T) { @@ -300,6 +341,12 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { user, err := store.GetUserByTokenID(id) require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) + + _, err = store.GetUserByTokenID("non-existing-id") + require.Error(t, err) + parsedErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } func newSqliteStore(t *testing.T) *SqliteStore {