diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 67a8e74fa..262c2af9b 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -15,6 +15,7 @@ type Manager interface { Disconnect(ctx context.Context, proxyID string) error Heartbeat(ctx context.Context, proxyID string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) + GetActiveClusters(ctx context.Context) ([]Cluster, error) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error } diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index 4c0964b5c..6350b36bd 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -15,6 +15,7 @@ type store interface { SaveProxy(ctx context.Context, p *proxy.Proxy) error UpdateProxyHeartbeat(ctx context.Context, proxyID string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error } @@ -105,6 +106,16 @@ func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error return addresses, nil } +// GetActiveClusters returns all active proxy clusters with their connected proxy count. +func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { + clusters, err := m.store.GetActiveProxyClusters(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) + return nil, err + } + return clusters, nil +} + // CleanupStale removes proxies that haven't sent heartbeat in the specified duration func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index b07a21122..e2dc4c2b6 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -93,6 +93,21 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) } +// GetActiveClusters mocks base method. +func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusters", ctx) + ret0, _ := ret[0].([]Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusters indicates an expected call of GetActiveClusters. +func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) +} + // Heartbeat mocks base method. func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error { m.ctrl.T.Helper() @@ -130,20 +145,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder { return m.recorder } -// GetOIDCValidationConfig mocks base method. -func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetOIDCValidationConfig") - ret0, _ := ret[0].(OIDCValidationConfig) - return ret0 -} - -// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig. -func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) -} - // ClusterSupportsCustomPorts mocks base method. func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool { m.ctrl.T.Helper() @@ -158,6 +159,20 @@ func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr) } +// GetOIDCValidationConfig mocks base method. +func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOIDCValidationConfig") + ret0, _ := ret[0].(OIDCValidationConfig) + return ret0 +} + +// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig. +func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) +} + // GetProxiesForCluster mocks base method. func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go index 699e1ed02..671eb109f 100644 --- a/management/internals/modules/reverseproxy/proxy/proxy.go +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -18,3 +18,9 @@ type Proxy struct { func (Proxy) TableName() string { return "proxies" } + +// Cluster represents a group of proxy nodes serving the same address. +type Cluster struct { + Address string + ConnectedProxies int +} diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index 39fd7e3ae..a49cbea35 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -4,9 +4,12 @@ package service import ( "context" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" ) type Manager interface { + GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error) CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index bdc1f3e65..cc5ccbb8e 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" ) // MockManager is a mock of Manager interface. @@ -107,6 +108,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID) } +// GetActiveClusters mocks base method. +func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusters indicates an expected call of GetActiveClusters. +func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID) +} + // GetAllServices mocks base method. func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go index c53219d2e..cd81efa88 100644 --- a/management/internals/modules/reverseproxy/service/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -34,6 +34,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma accesslogsmanager.RegisterEndpoints(router, accessLogsManager) + router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") @@ -177,3 +178,27 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } + +func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + apiClusters := make([]api.ProxyCluster, 0, len(clusters)) + for _, c := range clusters { + apiClusters = append(apiClusters, api.ProxyCluster{ + Address: c.Address, + ConnectedProxies: c.ConnectedProxies, + }) + } + + util.WriteJSONObject(r.Context(), w, apiClusters) +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 2251f5084..a7173a131 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -102,6 +102,19 @@ func (m *Manager) StartExposeReaper(ctx context.Context) { m.exposeReaper.StartExposeReaper(ctx) } +// GetActiveClusters returns all active proxy clusters with their connected proxy count. +func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetActiveProxyClusters(ctx) +} + func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 22fe4506b..0fa9a0dc1 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/types" ) @@ -90,6 +91,10 @@ func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} +func (m *mockReverseProxyManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { + return nil, nil +} + type mockUsersManager struct { users map[string]*types.User err error diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 647e8443b..2f77de86e 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/store" @@ -320,6 +321,10 @@ func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Contex func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]proxy.Cluster, error) { + return nil, nil +} + type testValidateSessionProxyManager struct{} func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { @@ -338,6 +343,10 @@ func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Co return nil, nil } +func (m *testValidateSessionProxyManager) GetActiveClusters(_ context.Context) ([]proxy.Cluster, error) { + return nil, nil +} + func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 3bed54e80..922bf4352 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/store" @@ -433,6 +434,10 @@ func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ stri func (m *testServiceManager) StartExposeReaper(_ context.Context) {} +func (m *testServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { + return nil, nil +} + func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { t.Helper() diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index b3fbfe141..32f2f8540 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5440,6 +5440,24 @@ func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string return addresses, nil } +// GetActiveProxyClusters returns all active proxy clusters with their connected proxy count. +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { + var clusters []proxy.Cluster + + result := s.db.Model(&proxy.Proxy{}). + Select("cluster_address as address, COUNT(*) as connected_proxies"). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Group("cluster_address"). + Scan(&clusters) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", result.Error) + return nil, status.Errorf(status.Internal, "get active proxy clusters") + } + + return clusters, nil +} + // CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { cutoffTime := time.Now().Add(-inactivityDuration) diff --git a/management/server/store/store.go b/management/server/store/store.go index 8bb52f38a..5dbfbd177 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -286,6 +286,7 @@ type Store interface { SaveProxy(ctx context.Context, proxy *proxy.Proxy) error UpdateProxyHeartbeat(ctx context.Context, proxyID string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index e75e35b94..05a6fe39f 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1287,6 +1287,21 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) } +// GetActiveProxyClusters mocks base method. +func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret0, _ := ret[0].([]proxy.Cluster) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) +} + // GetAllAccounts mocks base method. func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { m.ctrl.T.Helper() diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 8af151446..2fcbfe3cf 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -216,6 +216,10 @@ func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]strin return nil, nil } +func (m *testProxyManager) GetActiveClusters(_ context.Context) ([]nbproxy.Cluster, error) { + return nil, nil +} + func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { return nil } @@ -323,6 +327,10 @@ func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} +func (m *storeBackedServiceManager) GetActiveClusters(_ context.Context, _, _ string) ([]nbproxy.Cluster, error) { + return nil, nil +} + func strPtr(s string) *string { return &s }