mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-20 23:59:55 +00:00
Merge branch 'main' into refactor/permissions-manager
This commit is contained in:
@@ -131,9 +131,11 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||
})
|
||||
if !(peer.ProxyMeta.Embedded || peer.Meta.KernelVersion == "wasm") {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -27,19 +27,15 @@ type store interface {
|
||||
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
type clusterCapabilities interface {
|
||||
ClusterSupportsCustomPorts(clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(clusterAddr string) *bool
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
clusterCapabilities clusterCapabilities
|
||||
accountManager account.Manager
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyManager proxyManager
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyMgr proxyManager, accountManager account.Manager) Manager {
|
||||
@@ -51,11 +47,6 @@ func NewManager(store store, proxyMgr proxyManager, accountManager account.Manag
|
||||
}
|
||||
}
|
||||
|
||||
// SetClusterCapabilities sets the cluster capabilities provider for domain queries.
|
||||
func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) {
|
||||
m.clusterCapabilities = caps
|
||||
}
|
||||
|
||||
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
|
||||
domains, err := m.store.ListCustomDomains(ctx, accountID)
|
||||
if err != nil {
|
||||
@@ -83,10 +74,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
Type: domain.TypeFree,
|
||||
Validated: true,
|
||||
}
|
||||
if m.clusterCapabilities != nil {
|
||||
d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster)
|
||||
d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster)
|
||||
}
|
||||
d.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, cluster)
|
||||
d.RequireSubdomain = m.proxyManager.ClusterRequireSubdomain(ctx, cluster)
|
||||
ret = append(ret, d)
|
||||
}
|
||||
|
||||
@@ -100,8 +89,8 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
Type: domain.TypeCustom,
|
||||
Validated: d.Validated,
|
||||
}
|
||||
if m.clusterCapabilities != nil && d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster)
|
||||
if d.TargetCluster != "" {
|
||||
cd.SupportsCustomPorts = m.proxyManager.ClusterSupportsCustomPorts(ctx, d.TargetCluster)
|
||||
}
|
||||
// Custom domains never require a subdomain by default since
|
||||
// the account owns them and should be able to use the bare domain.
|
||||
|
||||
@@ -11,11 +11,13 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
@@ -34,6 +36,4 @@ type Controller interface {
|
||||
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
GetProxiesForCluster(clusterAddr string) []string
|
||||
ClusterSupportsCustomPorts(clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(clusterAddr string) *bool
|
||||
}
|
||||
|
||||
@@ -72,17 +72,6 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports.
|
||||
func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain returns whether the cluster requires a subdomain label.
|
||||
// Returns nil when no proxy has reported the capability (defaults to false).
|
||||
func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr)
|
||||
}
|
||||
|
||||
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||
|
||||
@@ -16,6 +16,8 @@ type store interface {
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
@@ -38,9 +40,14 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
caps = *capabilities
|
||||
}
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
@@ -48,6 +55,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
Capabilities: caps,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
@@ -118,6 +126,18 @@ func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
return clusters, nil
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
// supports custom ports. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterSupportsCustomPorts(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain returns whether any active proxy in the cluster
|
||||
// requires a subdomain. Returns nil when no proxy has reported capabilities.
|
||||
func (m Manager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
return m.store.GetClusterRequireSubdomain(ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
|
||||
@@ -50,18 +50,46 @@ func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
// ClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockManager) ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
|
||||
func (mr *MockManagerMockRecorder) ClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockManager)(nil).ClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain mocks base method.
|
||||
func (m *MockManager) ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
|
||||
func (mr *MockManagerMockRecorder) ClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockManager)(nil).ClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
@@ -145,34 +173,6 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
|
||||
func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr)
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain mocks base method.
|
||||
func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
|
||||
func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr)
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig mocks base method.
|
||||
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -2,6 +2,17 @@ package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
// Capabilities describes what a proxy can handle, as reported via gRPC.
|
||||
// Nil fields mean the proxy never reported this capability.
|
||||
type Capabilities struct {
|
||||
// SupportsCustomPorts indicates whether this proxy can bind arbitrary
|
||||
// ports for TCP/UDP services. TLS uses SNI routing and is not gated.
|
||||
SupportsCustomPorts *bool
|
||||
// RequireSubdomain indicates whether a subdomain label is required in
|
||||
// front of the cluster domain.
|
||||
RequireSubdomain *bool
|
||||
}
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
@@ -10,7 +21,8 @@ type Proxy struct {
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
@@ -74,24 +74,27 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
require.NoError(t, err)
|
||||
|
||||
mockCtrl := proxy.NewMockController(ctrl)
|
||||
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
|
||||
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
|
||||
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
|
||||
|
||||
mockCaps := proxy.NewMockManager(ctrl)
|
||||
mockCaps.EXPECT().ClusterSupportsCustomPorts(gomock.Any(), testCluster).Return(customPortsSupported).AnyTimes()
|
||||
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), testCluster).Return((*bool)(nil)).AnyTimes()
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
proxyController: mockCtrl,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
capabilities: mockCaps,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
|
||||
|
||||
@@ -72,20 +72,28 @@ type ClusterDeriver interface {
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
// CapabilityProvider queries proxy cluster capabilities from the database.
|
||||
type CapabilityProvider interface {
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
ClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
proxyController proxy.Controller
|
||||
capabilities CapabilityProvider
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
|
||||
func NewManager(store store.Store, accountManager account.Manager, proxyController proxy.Controller, capabilities CapabilityProvider, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
proxyController: proxyController,
|
||||
capabilities: capabilities,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
@@ -200,7 +208,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
|
||||
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, proxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -231,11 +239,11 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
// validateSubdomainRequirement checks whether the domain can be used bare
|
||||
// (without a subdomain label) on the given cluster. If the cluster reports
|
||||
// require_subdomain=true and the domain equals the cluster domain, it rejects.
|
||||
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
|
||||
func (m *Manager) validateSubdomainRequirement(ctx context.Context, domain, cluster string) error {
|
||||
if domain != cluster {
|
||||
return nil
|
||||
}
|
||||
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
|
||||
requireSub := m.capabilities.ClusterRequireSubdomain(ctx, cluster)
|
||||
if requireSub != nil && *requireSub {
|
||||
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
|
||||
}
|
||||
@@ -243,6 +251,8 @@ func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
|
||||
}
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if svc.Domain != "" {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
@@ -250,7 +260,7 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -270,12 +280,23 @@ func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *
|
||||
})
|
||||
}
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
|
||||
// clusterCustomPorts queries whether the cluster supports custom ports.
|
||||
// Must be called before entering a transaction: the underlying query uses
|
||||
// the main DB handle, which deadlocks when called inside a transaction
|
||||
// that already holds the connection.
|
||||
func (m *Manager) clusterCustomPorts(ctx context.Context, svc *service.Service) *bool {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
return m.capabilities.ClusterSupportsCustomPorts(ctx, svc.ProxyCluster)
|
||||
}
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
// customPorts must be pre-computed via clusterCustomPorts before entering a transaction.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service, customPorts *bool) error {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
|
||||
@@ -359,12 +380,14 @@ func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string
|
||||
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
|
||||
// for the same peer, preventing the per-peer limit from being bypassed.
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
customPorts := m.clusterCustomPorts(ctx, svc)
|
||||
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, svc, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -459,16 +482,49 @@ type serviceUpdateInfo struct {
|
||||
}
|
||||
|
||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||
effectiveCluster, err := m.resolveEffectiveCluster(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svcForCaps := *service
|
||||
svcForCaps.ProxyCluster = effectiveCluster
|
||||
customPorts := m.clusterCustomPorts(ctx, &svcForCaps)
|
||||
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo, customPorts)
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
|
||||
// resolveEffectiveCluster determines the cluster that will be used after the update.
|
||||
// It reads the existing service without locking and derives the new cluster if the domain changed.
|
||||
func (m *Manager) resolveEffectiveCluster(ctx context.Context, accountID string, svc *service.Service) (string, error) {
|
||||
existing, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, svc.ID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if existing.Domain == svc.Domain {
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
derived, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
|
||||
} else {
|
||||
return derived, nil
|
||||
}
|
||||
}
|
||||
|
||||
return existing.ProxyCluster, nil
|
||||
}
|
||||
|
||||
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo, customPorts *bool) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -493,7 +549,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
|
||||
if err := m.validateSubdomainRequirement(ctx, service.Domain, service.ProxyCluster); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -505,7 +561,7 @@ func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.St
|
||||
m.preserveListenPort(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
|
||||
if err := m.ensureL4Port(ctx, transaction, service, customPorts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
|
||||
@@ -12,9 +12,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
|
||||
permissions2 "github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
@@ -696,8 +693,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1316,11 +1313,11 @@ func TestValidateSubdomainRequirement(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
mockCtrl := proxy.NewMockController(ctrl)
|
||||
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
|
||||
mockCaps := proxy.NewMockManager(ctrl)
|
||||
mockCaps.EXPECT().ClusterRequireSubdomain(gomock.Any(), tc.cluster).Return(tc.requireSubdomain).AnyTimes()
|
||||
|
||||
mgr := &Manager{proxyController: mockCtrl}
|
||||
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
|
||||
mgr := &Manager{capabilities: mockCaps}
|
||||
err := mgr.validateSubdomainRequirement(context.Background(), tc.domain, tc.cluster)
|
||||
if tc.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "requires a subdomain label")
|
||||
|
||||
@@ -787,6 +787,11 @@ func (s *Service) validateHTTPTargets() error {
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
// L4 services have a single target; per-target disable is meaningless
|
||||
// (use the service-level Enabled flag instead). Force it on so that
|
||||
// buildPathMappings always includes the target in the proto.
|
||||
target.Enabled = true
|
||||
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/job"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
@@ -71,6 +72,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
signingKeyRefreshEnabled := s.Config.HttpConfig.IdpSignKeyRefreshEnabled
|
||||
issuer := s.Config.HttpConfig.AuthIssuer
|
||||
userIDClaim := s.Config.HttpConfig.AuthUserIDClaim
|
||||
var keyFetcher nbjwt.KeyFetcher
|
||||
|
||||
// Use embedded IdP configuration if available
|
||||
if oauthProvider := s.OAuthConfigProvider(); oauthProvider != nil {
|
||||
@@ -78,8 +80,11 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
if len(audiences) > 0 {
|
||||
audience = audiences[0] // Use the first client ID as the primary audience
|
||||
}
|
||||
// Use localhost keys location for internal validation (management has embedded Dex)
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
keyFetcher = oauthProvider.GetKeyFetcher()
|
||||
// Fall back to default keys location if direct key fetching is not available
|
||||
if keyFetcher == nil {
|
||||
keysLocation = oauthProvider.GetLocalKeysLocation()
|
||||
}
|
||||
signingKeyRefreshEnabled = true
|
||||
issuer = oauthProvider.GetIssuer()
|
||||
userIDClaim = oauthProvider.GetUserIDClaim()
|
||||
@@ -92,7 +97,8 @@ func (s *BaseServer) AuthManager() auth.Manager {
|
||||
keysLocation,
|
||||
userIDClaim,
|
||||
audiences,
|
||||
signingKeyRefreshEnabled)
|
||||
signingKeyRefreshEnabled,
|
||||
keyFetcher)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -117,9 +117,11 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
|
||||
// Use embedded IdP service if embedded Dex is configured and enabled.
|
||||
// Legacy IdpManager won't be used anymore even if configured.
|
||||
if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled {
|
||||
embeddedEnabled := s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled
|
||||
if embeddedEnabled {
|
||||
idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create embedded IDP service: %v", err)
|
||||
@@ -165,19 +167,19 @@ func (s *BaseServer) GroupsManager() groups.Manager {
|
||||
|
||||
func (s *BaseServer) ResourcesManager() resources.Manager {
|
||||
return Create(s, func() resources.Manager {
|
||||
return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
return resources.NewManager(s.Store(), s.GroupsManager(), s.AccountManager(), s.ServiceManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) RoutesManager() routers.Manager {
|
||||
return Create(s, func() routers.Manager {
|
||||
return routers.NewManager(s.Store(), s.PermissionsManager(), s.AccountManager())
|
||||
return routers.NewManager(s.Store(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworksManager() networks.Manager {
|
||||
return Create(s, func() networks.Manager {
|
||||
return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
return networks.NewManager(s.Store(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -195,7 +197,7 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
||||
|
||||
func (s *BaseServer) ServiceManager() service.Manager {
|
||||
return Create(s, func() service.Manager {
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.ServiceProxyController(), s.ProxyManager(), s.ReverseProxyDomainManager())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -212,9 +214,6 @@ func (s *BaseServer) ProxyManager() proxy.Manager {
|
||||
func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
||||
return Create(s, func() *manager.Manager {
|
||||
m := manager.NewManager(s.Store(), s.ProxyManager(), s.AccountManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
m.SetClusterCapabilities(s.ServiceProxyController())
|
||||
})
|
||||
return &m
|
||||
})
|
||||
}
|
||||
|
||||
@@ -182,9 +182,21 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err)
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
}
|
||||
}
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
@@ -297,6 +309,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
if !proxyAcceptsMapping(conn, m) {
|
||||
continue
|
||||
}
|
||||
mappings = append(mappings, m)
|
||||
}
|
||||
return mappings, nil
|
||||
@@ -445,22 +460,46 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
|
||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||
for _, proxyID := range proxyIDs {
|
||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||
conn := connVal.(*proxyConnection)
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
connVal, ok := s.connectedProxies.Load(proxyID)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if !proxyAcceptsMapping(conn, update) {
|
||||
log.WithContext(ctx).Debugf("Skipping proxy %s: does not support custom ports for mapping %s", proxyID, update.Id)
|
||||
continue
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// proxyAcceptsMapping returns whether the proxy should receive this mapping.
|
||||
// Old proxies that never reported capabilities are skipped for non-TLS L4
|
||||
// mappings with a custom listen port, since they don't understand the
|
||||
// protocol. Proxies that report capabilities (even SupportsCustomPorts=false)
|
||||
// are new enough to handle the mapping. TLS uses SNI routing and works on
|
||||
// any proxy. Delete operations are always sent so proxies can clean up.
|
||||
func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) bool {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
return true
|
||||
}
|
||||
if mapping.ListenPort == 0 || mapping.Mode == "tls" {
|
||||
return true
|
||||
}
|
||||
// Old proxies that never reported capabilities don't understand
|
||||
// custom port mappings.
|
||||
return conn.capabilities != nil && conn.capabilities.SupportsCustomPorts != nil
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
@@ -508,64 +547,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
}
|
||||
}
|
||||
|
||||
// ClusterSupportsCustomPorts returns whether any connected proxy in the given
|
||||
// cluster reports custom port support. Returns nil if no proxy has reported
|
||||
// capabilities (old proxies that predate the field).
|
||||
func (s *ProxyServiceServer) ClusterSupportsCustomPorts(clusterAddr string) *bool {
|
||||
if s.proxyController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hasCapabilities bool
|
||||
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
|
||||
connVal, ok := s.connectedProxies.Load(pid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.capabilities == nil || conn.capabilities.SupportsCustomPorts == nil {
|
||||
continue
|
||||
}
|
||||
if *conn.capabilities.SupportsCustomPorts {
|
||||
return ptr(true)
|
||||
}
|
||||
hasCapabilities = true
|
||||
}
|
||||
if hasCapabilities {
|
||||
return ptr(false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClusterRequireSubdomain returns whether any connected proxy in the given
|
||||
// cluster reports that a subdomain is required. Returns nil if no proxy has
|
||||
// reported the capability (defaults to not required).
|
||||
func (s *ProxyServiceServer) ClusterRequireSubdomain(clusterAddr string) *bool {
|
||||
if s.proxyController == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var hasCapabilities bool
|
||||
for _, pid := range s.proxyController.GetProxiesForCluster(clusterAddr) {
|
||||
connVal, ok := s.connectedProxies.Load(pid)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
conn := connVal.(*proxyConnection)
|
||||
if conn.capabilities == nil || conn.capabilities.RequireSubdomain == nil {
|
||||
continue
|
||||
}
|
||||
if *conn.capabilities.RequireSubdomain {
|
||||
return ptr(true)
|
||||
}
|
||||
hasCapabilities = true
|
||||
}
|
||||
if hasCapabilities {
|
||||
return ptr(false)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) {
|
||||
service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId())
|
||||
if err != nil {
|
||||
|
||||
@@ -53,14 +53,6 @@ func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clus
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) ClusterSupportsCustomPorts(_ string) *bool {
|
||||
return ptr(true)
|
||||
}
|
||||
|
||||
func (c *testProxyController) ClusterRequireSubdomain(_ string) *bool {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@@ -355,14 +347,14 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
|
||||
// Proxy A supports custom ports.
|
||||
chA := registerFakeProxyWithCaps(s, "proxy-a", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
// Proxy B does NOT support custom ports (shared cloud proxy).
|
||||
chB := registerFakeProxyWithCaps(s, "proxy-b", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
// Modern proxy reports capabilities.
|
||||
chModern := registerFakeProxyWithCaps(s, "proxy-modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
// Legacy proxy never reported capabilities (nil).
|
||||
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// TLS passthrough works on all proxies regardless of custom port support.
|
||||
// TLS passthrough with custom port: all proxies receive it (SNI routing).
|
||||
tlsMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-tls",
|
||||
@@ -375,12 +367,26 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tlsMapping, cluster)
|
||||
|
||||
msgA := drainMapping(chA)
|
||||
msgB := drainMapping(chB)
|
||||
assert.NotNil(t, msgA, "proxy-a should receive TLS mapping")
|
||||
assert.NotNil(t, msgB, "proxy-b should receive TLS mapping (passthrough works on all proxies)")
|
||||
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TLS mapping")
|
||||
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive TLS mapping (SNI works on all)")
|
||||
|
||||
// Send an HTTP mapping: both should receive it.
|
||||
// TCP mapping with custom port: only modern proxy receives it.
|
||||
tcpMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-tcp",
|
||||
AccountId: "account-1",
|
||||
Domain: "db.example.com",
|
||||
Mode: "tcp",
|
||||
ListenPort: 5432,
|
||||
Path: []*proto.PathMapping{{Target: "10.0.0.5:5432"}},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
|
||||
|
||||
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive TCP custom-port mapping")
|
||||
assert.Nil(t, drainMapping(chLegacy), "legacy proxy should NOT receive TCP custom-port mapping")
|
||||
|
||||
// HTTP mapping (no listen port): both receive it.
|
||||
httpMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-http",
|
||||
@@ -391,10 +397,16 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, httpMapping, cluster)
|
||||
|
||||
msgA = drainMapping(chA)
|
||||
msgB = drainMapping(chB)
|
||||
assert.NotNil(t, msgA, "proxy-a should receive HTTP mapping")
|
||||
assert.NotNil(t, msgB, "proxy-b should receive HTTP mapping")
|
||||
assert.NotNil(t, drainMapping(chModern), "modern proxy should receive HTTP mapping")
|
||||
assert.NotNil(t, drainMapping(chLegacy), "legacy proxy should receive HTTP mapping")
|
||||
|
||||
// Proxy that reports SupportsCustomPorts=false still receives custom-port
|
||||
// mappings because it understands the protocol (it's new enough).
|
||||
chNewNoCustom := registerFakeProxyWithCaps(s, "proxy-new-no-custom", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
|
||||
s.SendServiceUpdateToCluster(ctx, tcpMapping, cluster)
|
||||
|
||||
assert.NotNil(t, drainMapping(chNewNoCustom), "new proxy with SupportsCustomPorts=false should still receive mapping")
|
||||
}
|
||||
|
||||
func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
@@ -408,7 +420,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
|
||||
chShared := registerFakeProxyWithCaps(s, "proxy-shared", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
// Legacy proxy (no capabilities) still receives TLS since it uses SNI.
|
||||
chLegacy := registerFakeProxy(s, "proxy-legacy", cluster)
|
||||
|
||||
tlsMapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
@@ -421,8 +434,8 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) {
|
||||
|
||||
s.SendServiceUpdateToCluster(context.Background(), tlsMapping, cluster)
|
||||
|
||||
msg := drainMapping(chShared)
|
||||
assert.NotNil(t, msg, "shared proxy should receive TLS mapping even without custom port support")
|
||||
msg := drainMapping(chLegacy)
|
||||
assert.NotNil(t, msg, "legacy proxy should receive TLS mapping (SNI works without custom port support)")
|
||||
}
|
||||
|
||||
// TestServiceModifyNotifications exercises every possible modification
|
||||
@@ -589,7 +602,7 @@ func TestServiceModifyNotifications(t *testing.T) {
|
||||
s.SetProxyController(newTestProxyController())
|
||||
const cluster = "proxy.example.com"
|
||||
chModern := registerFakeProxyWithCaps(s, "modern", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(true)})
|
||||
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
chLegacy := registerFakeProxy(s, "legacy", cluster)
|
||||
|
||||
// TLS passthrough works on all proxies regardless of custom port support
|
||||
s.SendServiceUpdateToCluster(ctx, tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED), cluster)
|
||||
@@ -608,7 +621,7 @@ func TestServiceModifyNotifications(t *testing.T) {
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
const cluster = "proxy.example.com"
|
||||
chLegacy := registerFakeProxyWithCaps(s, "legacy", cluster, &proto.ProxyCapabilities{SupportsCustomPorts: ptr(false)})
|
||||
chLegacy := registerFakeProxy(s, "legacy", cluster)
|
||||
|
||||
mapping := tlsOnlyMapping(proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED)
|
||||
mapping.ListenPort = 0 // default port
|
||||
|
||||
Reference in New Issue
Block a user