diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 3f08d12ba..99ccb1539 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -93,8 +93,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp permissionsManagerMock := permissions.NewMockManager(ctrl) peersmanager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, eventStore) + iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index afe4622b3..90c8cbc60 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1557,7 +1557,8 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri permissionsManager := permissions.NewManager(store) peersManager := peers.NewManager(store, permissionsManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, eventStore) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/client/server/server_test.go b/client/server/server_test.go index 493c8601a..87889cbce 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -299,8 +299,9 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve permissionsManagerMock := permissions.NewMockManager(ctrl) peersManager := peers.NewManager(store, permissionsManagerMock) + settingsManagerMock := settings.NewMockManager(ctrl) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, eventStore) + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) diff --git a/go.mod b/go.mod index 68730bf53..70e52875f 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f + github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 4783c47eb..3fdef5d08 100644 --- a/go.sum +++ b/go.sum @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107 h1:ZJwhKexMlK15B/Ld+1T8VYE2Mt1lk1kf2DlXr46EHcw= github.com/netbirdio/ice/v4 v4.0.0-20250827161942-426799a23107/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f h1:r1gnjw0TfkaDLSCmAE3g5N5ulcd5WpFHaGrqQomCXP4= -github.com/netbirdio/management-integrations/integrations v0.0.0-20250826184705-1866b8dd841f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0 h1:9BUqQHPVOGr0edk8EifUBUfTr2Ob0ypAPxtasUApBxQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250906095204-f87a07690ba0/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index f9023b204..984a56a39 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -20,7 +20,11 @@ func (s *BaseServer) PeersUpdateManager() *server.PeersUpdateManager { func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValidator { return Create(s, func() integrated_validator.IntegratedValidator { - integratedPeerValidator, err := integrations.NewIntegratedValidator(context.Background(), s.PeersManager(), s.EventStore()) + integratedPeerValidator, err := integrations.NewIntegratedValidator( + context.Background(), + s.PeersManager(), + s.SettingsManager(), + s.EventStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/server/peers/manager.go b/management/server/peers/manager.go index 50e36a880..cb135f4ac 100644 --- a/management/server/peers/manager.go +++ b/management/server/peers/manager.go @@ -18,6 +18,7 @@ type Manager interface { GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) GetPeerAccountID(ctx context.Context, peerID string) (string, error) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) } type managerImpl struct { @@ -61,3 +62,7 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID) } + +func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs) +} diff --git a/management/server/peers/manager_mock.go b/management/server/peers/manager_mock.go index b247a1752..994f8346b 100644 --- a/management/server/peers/manager_mock.go +++ b/management/server/peers/manager_mock.go @@ -79,3 +79,18 @@ func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *go mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID) } + +// GetPeersByGroupIDs mocks base method. +func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs) + ret0, _ := ret[0].([]*peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs. +func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs) +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 45561f950..027938320 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2847,3 +2847,22 @@ func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, i } return nil } + +func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) { + if len(groupIDs) == 0 { + return []*nbpeer.Peer{}, nil + } + + var peers []*nbpeer.Peer + peerIDsSubquery := s.db.Model(&types.GroupPeer{}). + Select("DISTINCT peer_id"). + Where("account_id = ? AND group_id IN ?", accountID, groupIDs) + + result := s.db.Where("id IN (?)", peerIDsSubquery).Find(&peers) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by group IDs: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by group IDs") + } + + return peers, nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 935b0a595..d40c4664c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3607,3 +3607,113 @@ func intToIPv4(n uint32) net.IP { binary.BigEndian.PutUint32(ip, n) return ip } + +func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group1ID := "test-group-1" + group2ID := "test-group-2" + emptyGroupID := "empty-group" + + peer1 := "cfefqs706sqkneg59g4g" + peer2 := "cfeg6sf06sqkneg59g50" + + tests := []struct { + name string + groupIDs []string + expectedPeers []string + expectedCount int + }{ + { + name: "retrieve peers from single group with multiple peers", + groupIDs: []string{group1ID}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from single group with one peer", + groupIDs: []string{group2ID}, + expectedPeers: []string{peer1}, + expectedCount: 1, + }, + { + name: "retrieve peers from multiple groups (with overlap)", + groupIDs: []string{group1ID, group2ID}, + expectedPeers: []string{peer1, peer2}, // should deduplicate + expectedCount: 2, + }, + { + name: "retrieve peers from existing 'All' group", + groupIDs: []string{"cfefqs706sqkneg59g3g"}, // All group from test data + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + { + name: "retrieve peers from empty group", + groupIDs: []string{emptyGroupID}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "retrieve peers from non-existing group", + groupIDs: []string{"non-existing-group"}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedPeers: []string{}, + expectedCount: 0, + }, + { + name: "mix of existing and non-existing groups", + groupIDs: []string{group1ID, "non-existing-group"}, + expectedPeers: []string{peer1, peer2}, + expectedCount: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + ctx := context.Background() + + groups := []*types.Group{ + { + ID: group1ID, + AccountID: accountID, + }, + { + ID: group2ID, + AccountID: accountID, + }, + } + require.NoError(t, store.CreateGroups(ctx, accountID, groups)) + + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer2, group1ID)) + require.NoError(t, store.AddPeerToGroup(ctx, accountID, peer1, group2ID)) + + peers, err := store.GetPeersByGroupIDs(ctx, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + + if tt.expectedCount > 0 { + actualPeerIDs := make([]string, len(peers)) + for i, peer := range peers { + actualPeerIDs[i] = peer.ID + } + assert.ElementsMatch(t, tt.expectedPeers, actualPeerIDs) + + // Verify all returned peers belong to the correct account + for _, peer := range peers { + assert.Equal(t, accountID, peer.AccountID) + } + } + }) + } +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 545549410..3c9d896b0 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -136,6 +136,7 @@ type Store interface { GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) + GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index b04cdd96a..becc10ded 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -86,7 +86,9 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { AnyTimes() peersManger := peers.NewManager(store, permissionsManagerMock) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, eventStore) + settingsManagerMock := settings.NewMockManager(ctrl) + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err)