From 2806d7316100fb59f6498ebd6aae6975faf62477 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 12 Nov 2024 13:38:34 +0300 Subject: [PATCH 1/2] Add tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 277 +++++++++++++++++++++++++++- 1 file changed, 274 insertions(+), 3 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20409798b..114da1ee6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} From 20a5afc3596b41028a12a13b5820dff0ede2c6a7 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 12 Nov 2024 14:19:22 +0100 Subject: [PATCH 2/2] [management] Add more logs to the peer update processes (#2881) --- management/server/account.go | 8 ++++---- management/server/grpcserver.go | 9 ++++++++- management/server/peer.go | 18 ++++++++++-------- management/server/status/error.go | 5 +++++ management/server/updatechannel.go | 5 ++++- management/server/user.go | 12 ++++++++++-- 6 files changed, 41 insertions(+), 16 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 583853f25..bf6039229 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2129,7 +2129,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st if settings.GroupsPropagationEnabled { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return fmt.Errorf("error getting account: %w", err) + return status.NewGetAccountError(err) } if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { @@ -2290,12 +2290,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, status.NewGetAccountError(err) } peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) } err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) @@ -2314,7 +2314,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account account, err := am.Store.GetAccount(ctx, accountID) if err != nil { - return err + return status.NewGetAccountError(err) } err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index efe088b27..9c12336f8 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -180,6 +180,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), extractPeerMeta(ctx, syncReq.GetMeta()), realIP) if err != nil { + log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) return mapError(ctx, err) } @@ -207,6 +208,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi // handleUpdates sends updates to the connected peer until the updates channel is closed. func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { + log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { // condition when there are some updates @@ -260,10 +262,15 @@ func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, p unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() - _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) + err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) + if err != nil { + log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) + } s.peersUpdateManager.CloseChannel(ctx, peer.ID) s.secretsManager.CancelRefresh(peer.ID) s.ephemeralManager.OnPeerDisconnected(ctx, peer) + + log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key) } func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { diff --git a/management/server/peer.go b/management/server/peer.go index 8ced2a1de..9784650de 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -110,14 +110,16 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { - return err + return fmt.Errorf("failed to find peer by pub key: %w", err) } expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) if err != nil { - return err + return fmt.Errorf("failed to update peer status and location: %w", err) } + log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected) + if peer.AddedWithSSOLogin() { if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { am.checkAndSchedulePeerLoginExpiration(ctx, account) @@ -168,7 +170,7 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { - return false, err + return false, fmt.Errorf("failed to save peer status: %w", err) } return oldStatus.LoginExpired, nil @@ -587,7 +589,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return nil, nil, nil, fmt.Errorf("error getting account: %w", err) + return nil, nil, nil, status.NewGetAccountError(err) } allGroup, err := account.GetGroupAll() @@ -640,7 +642,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if peer.UserID != "" { user, err := account.FindUser(peer.UserID) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) } err = checkIfPeerOwnerIsBlocked(peer, user) @@ -657,7 +659,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if updated { err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) } if sync.UpdateAccountPeers { @@ -667,7 +669,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err) } var postureChecks []*posture.Checks @@ -685,7 +687,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac validPeersMap, err := am.GetValidatedPeers(account) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } postureChecks = am.getPeerPostureChecks(account, peer) diff --git a/management/server/status/error.go b/management/server/status/error.go index a415d5b6e..f1f3f16e6 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -135,3 +135,8 @@ func NewStoreContextCanceledError(duration time.Duration) error { func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") } + +// NewGetAccountError creates a new Error with Internal type for an issue getting account +func NewGetAccountError(err error) error { + return Errorf(Internal, "error getting account: %s", err) +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 59b6fd094..d338b84b1 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -96,9 +96,12 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) + + log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) + return } - log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) + log.WithContext(ctx).Debugf("closing updates channel: peer %s has no channel", peerID) } // CloseChannels closes updates channel for each given peer diff --git a/management/server/user.go b/management/server/user.go index 1368b76b1..5e0d9d034 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -9,14 +9,16 @@ import ( "time" "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" + nbContext "github.com/netbirdio/netbird/management/server/context" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) const ( @@ -1105,6 +1107,9 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { + // nolint:staticcheck + ctx = context.WithValue(ctx, nbContext.PeerIDKey, peer.Key) + if peer.Status.LoginExpired { continue } @@ -1112,8 +1117,11 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.MarkLoginExpired(true) account.UpdatePeer(peer) if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { - return err + return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err) } + + log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID) + am.StoreEvent( ctx, peer.UserID, peer.ID, account.Id,