diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go index 88156ef38..e491e2bbc 100644 --- a/management/internals/modules/reverseproxy/proxy/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -16,7 +16,6 @@ type Manager interface { Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveClusterAddresses(ctx context.Context) ([]string, error) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) - GetActiveClusters(ctx context.Context) ([]Cluster, error) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool ClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go index f168f7af4..58e612f6e 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -17,7 +17,7 @@ type store interface { UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool @@ -116,16 +116,6 @@ func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, erro return addresses, nil } -// GetActiveClusters returns all active proxy clusters with their connected proxy count. -func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) { - clusters, err := m.store.GetActiveProxyClusters(ctx) - if err != nil { - log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err) - return nil, err - } - return clusters, nil -} - // ClusterSupportsCustomPorts returns whether any active proxy in the cluster // supports custom ports. Returns nil when no proxy has reported capabilities. func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool { diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go index 8f1ef7569..8bbb275ff 100644 --- a/management/internals/modules/reverseproxy/proxy/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/proxy/manager/manager_test.go @@ -57,7 +57,7 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context } return nil, nil } -func (m *mockStore) GetActiveProxyClusters(_ context.Context) ([]proxy.Cluster, error) { +func (m *mockStore) GetActiveProxyClusters(_ context.Context, _ string) ([]proxy.Cluster, error) { return nil, nil } func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error { diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go index b78450796..5d43fae7a 100644 --- a/management/internals/modules/reverseproxy/proxy/manager_mock.go +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -148,20 +148,6 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, acco return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID) } -func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveClusters", ctx) - ret0, _ := ret[0].([]Cluster) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetActiveClusters indicates an expected call of GetActiveClusters. -func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx) -} - // Heartbeat mocks base method. func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index 285afd95f..1cec7a8c2 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -120,7 +120,7 @@ func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID strin return nil, status.NewPermissionDeniedError() } - return m.store.GetActiveProxyClusters(ctx) + return m.store.GetActiveProxyClusters(ctx, accountID) } // DeleteAccountCluster removes all proxy registrations for the given cluster address diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2d68e7d6a..649164554 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -5526,13 +5526,15 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd return nil } -// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies +// GetActiveProxyClusterAddresses returns the unique cluster addresses of active +// shared proxies (those without an account scope). BYOP cluster addresses are +// excluded; use GetActiveProxyClusterAddressesForAccount to retrieve them. func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { var addresses []string result := s.db. Model(&proxy.Proxy{}). - Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). + Where("account_id IS NULL AND status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5549,7 +5551,7 @@ func (s *SqlStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context, result := s.db. Model(&proxy.Proxy{}). - Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-2*time.Minute)). + Where("account_id = ? AND status = ? AND last_seen > ?", accountID, proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). Distinct("cluster_address"). Pluck("cluster_address", &addresses) @@ -5606,12 +5608,13 @@ func (s *SqlStore) DeleteAccountCluster(ctx context.Context, clusterAddress, acc return nil } -func (s *SqlStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +func (s *SqlStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { var clusters []proxy.Cluster result := s.db.Model(&proxy.Proxy{}). Select("MIN(id) as id, cluster_address as address, COUNT(*) as connected_proxies, COUNT(account_id) > 0 as self_hosted"). - Where("status = ? AND last_seen > ?", proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold)). + Where("status = ? AND last_seen > ? AND (account_id IS NULL OR account_id = ?)", + proxy.StatusConnected, time.Now().Add(-proxyActiveThreshold), accountID). Group("cluster_address"). Scan(&clusters) diff --git a/management/server/store/store.go b/management/server/store/store.go index a31c97bee..8bb84c2bb 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -291,7 +291,7 @@ type Store interface { UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) - GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) + GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 78b45f3f2..d199a1210 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1346,18 +1346,18 @@ func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddressesForAccount(ctx, a } // GetActiveProxyClusters mocks base method. -func (m *MockStore) GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error) { +func (m *MockStore) GetActiveProxyClusters(ctx context.Context, accountID string) ([]proxy.Cluster, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx) + ret := m.ctrl.Call(m, "GetActiveProxyClusters", ctx, accountID) ret0, _ := ret[0].([]proxy.Cluster) ret1, _ := ret[1].(error) return ret0, ret1 } // GetActiveProxyClusters indicates an expected call of GetActiveProxyClusters. -func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx interface{}) *gomock.Call { +func (mr *MockStoreMockRecorder) GetActiveProxyClusters(ctx, accountID interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusters", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusters), ctx, accountID) } // GetAllAccounts mocks base method.