diff --git a/management/internals/modules/reverseproxy/domain/manager/manager.go b/management/internals/modules/reverseproxy/domain/manager/manager.go index 55ca24ac2..12dd051fd 100644 --- a/management/internals/modules/reverseproxy/domain/manager/manager.go +++ b/management/internals/modules/reverseproxy/domain/manager/manager.go @@ -27,21 +27,21 @@ type store interface { DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error } -type proxyURLProvider interface { - GetConnectedProxyURLs() []string +type proxyManager interface { + GetActiveClusterAddresses(ctx context.Context) ([]string, error) } type Manager struct { store store validator domain.Validator - proxyURLProvider proxyURLProvider + proxyManager proxyManager permissionsManager permissions.Manager } -func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager { +func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager { return Manager{ - store: store, - proxyURLProvider: proxyURLProvider, + store: store, + proxyManager: proxyMgr, validator: domain.Validator{ Resolver: net.DefaultResolver, }, @@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d // Add connected proxy clusters as free domains. // The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io"). - allowList := m.proxyURLAllowList() - log.WithFields(log.Fields{ + allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) + return nil, err + } + log.WithContext(ctx).WithFields(log.Fields{ "accountID": accountID, "proxyAllowList": allowList, }).Debug("getting domains with proxy allow list") @@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName } // Verify the target cluster is in the available clusters - allowList := m.proxyURLAllowList() + allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err) + } clusterValid := false for _, cluster := range allowList { if cluster == targetCluster { @@ -221,25 +228,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID } } +// GetClusterDomains returns a list of proxy cluster domains. func (m Manager) GetClusterDomains() []string { - return m.proxyURLAllowList() -} - -// proxyURLAllowList retrieves a list of currently connected proxies and -// their URLs -func (m Manager) proxyURLAllowList() []string { - var reverseProxyAddresses []string - if m.proxyURLProvider != nil { - reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs() + if m.proxyManager == nil { + return nil } - return reverseProxyAddresses + addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background()) + if err != nil { + return nil + } + return addresses } // DeriveClusterFromDomain determines the proxy cluster for a given domain. // For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain. // For custom domains, the cluster is determined by checking the registered custom domain's target cluster. func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) { - allowList := m.proxyURLAllowList() + allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx) + if err != nil { + return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err) + } if len(allowList) == 0 { return "", fmt.Errorf("no proxy clusters available") } diff --git a/management/internals/modules/reverseproxy/proxy/manager.go b/management/internals/modules/reverseproxy/proxy/manager.go new file mode 100644 index 000000000..15f2f9f54 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager.go @@ -0,0 +1,36 @@ +package proxy + +//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod + +import ( + "context" + "time" + + "github.com/netbirdio/netbird/shared/management/proto" +) + +// Manager defines the interface for proxy operations +type Manager interface { + Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error + Disconnect(ctx context.Context, proxyID string) error + Heartbeat(ctx context.Context, proxyID string) error + GetActiveClusterAddresses(ctx context.Context) ([]string, error) + CleanupStale(ctx context.Context, inactivityDuration time.Duration) error +} + +// OIDCValidationConfig contains the OIDC configuration needed for token validation. +type OIDCValidationConfig struct { + Issuer string + Audiences []string + KeysLocation string + MaxTokenAgeSeconds int64 +} + +// Controller is responsible for managing proxy clusters and routing service updates. +type Controller interface { + SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) + GetOIDCValidationConfig() OIDCValidationConfig + RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error + UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error + GetProxiesForCluster(clusterAddr string) []string +} diff --git a/management/internals/modules/reverseproxy/proxy/manager/controller.go b/management/internals/modules/reverseproxy/proxy/manager/controller.go new file mode 100644 index 000000000..e5b3e9886 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/controller.go @@ -0,0 +1,88 @@ +package manager + +import ( + "context" + "sync" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + "github.com/netbirdio/netbird/shared/management/proto" +) + +// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC. +type GRPCController struct { + proxyGRPCServer *nbgrpc.ProxyServiceServer + // Map of cluster address -> set of proxy IDs + clusterProxies sync.Map + metrics *metrics +} + +// NewGRPCController creates a new GRPCController. +func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) { + m, err := newMetrics(meter) + if err != nil { + return nil, err + } + + return &GRPCController{ + proxyGRPCServer: proxyGRPCServer, + metrics: m, + }, nil +} + +// SendServiceUpdateToCluster sends a service update to a specific proxy cluster. +func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) { + c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr) + c.metrics.IncrementServiceUpdateSendCount(clusterAddr) +} + +// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server. +func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig { + return c.proxyGRPCServer.GetOIDCValidationConfig() +} + +// RegisterProxyToCluster registers a proxy to a specific cluster for routing. +func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error { + if clusterAddr == "" { + return nil + } + proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) + proxySet.(*sync.Map).Store(proxyID, struct{}{}) + log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr) + + c.metrics.IncrementProxyConnectionCount(clusterAddr) + + return nil +} + +// UnregisterProxyFromCluster removes a proxy from a cluster. +func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error { + if clusterAddr == "" { + return nil + } + if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok { + proxySet.(*sync.Map).Delete(proxyID) + log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr) + + c.metrics.DecrementProxyConnectionCount(clusterAddr) + } + return nil +} + +// GetProxiesForCluster returns all proxy IDs registered for a specific cluster. +func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string { + proxySet, ok := c.clusterProxies.Load(clusterAddr) + if !ok { + return nil + } + + var proxies []string + proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { + proxies = append(proxies, key.(string)) + return true + }) + return proxies +} diff --git a/management/internals/modules/reverseproxy/proxy/manager/manager.go b/management/internals/modules/reverseproxy/proxy/manager/manager.go new file mode 100644 index 000000000..4c0964b5c --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/manager.go @@ -0,0 +1,115 @@ +package manager + +import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/metric" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" +) + +// store defines the interface for proxy persistence operations +type store interface { + SaveProxy(ctx context.Context, p *proxy.Proxy) error + UpdateProxyHeartbeat(ctx context.Context, proxyID string) error + GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error +} + +// Manager handles all proxy operations +type Manager struct { + store store + metrics *metrics +} + +// NewManager creates a new proxy Manager +func NewManager(store store, meter metric.Meter) (*Manager, error) { + m, err := newMetrics(meter) + if err != nil { + return nil, err + } + + return &Manager{ + store: store, + metrics: m, + }, nil +} + +// Connect registers a new proxy connection in the database +func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + now := time.Now() + p := &proxy.Proxy{ + ID: proxyID, + ClusterAddress: clusterAddress, + IPAddress: ipAddress, + LastSeen: now, + ConnectedAt: &now, + Status: "connected", + } + + if err := m.store.SaveProxy(ctx, p); err != nil { + log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err) + return err + } + + log.WithContext(ctx).WithFields(log.Fields{ + "proxyID": proxyID, + "clusterAddress": clusterAddress, + "ipAddress": ipAddress, + }).Info("proxy connected") + + return nil +} + +// Disconnect marks a proxy as disconnected in the database +func (m Manager) Disconnect(ctx context.Context, proxyID string) error { + now := time.Now() + p := &proxy.Proxy{ + ID: proxyID, + Status: "disconnected", + DisconnectedAt: &now, + LastSeen: now, + } + + if err := m.store.SaveProxy(ctx, p); err != nil { + log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err) + return err + } + + log.WithContext(ctx).WithFields(log.Fields{ + "proxyID": proxyID, + }).Info("proxy disconnected") + + return nil +} + +// Heartbeat updates the proxy's last seen timestamp +func (m Manager) Heartbeat(ctx context.Context, proxyID string) error { + if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil { + log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err) + return err + } + m.metrics.IncrementProxyHeartbeatCount() + return nil +} + +// GetActiveClusterAddresses returns all unique cluster addresses for active proxies +func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + addresses, err := m.store.GetActiveProxyClusterAddresses(ctx) + if err != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err) + return nil, err + } + return addresses, nil +} + +// CleanupStale removes proxies that haven't sent heartbeat in the specified duration +func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { + if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil { + log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err) + return err + } + return nil +} diff --git a/management/internals/modules/reverseproxy/proxy/manager/metrics.go b/management/internals/modules/reverseproxy/proxy/manager/metrics.go new file mode 100644 index 000000000..2b402cead --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager/metrics.go @@ -0,0 +1,74 @@ +package manager + +import ( + "context" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" +) + +type metrics struct { + proxyConnectionCount metric.Int64UpDownCounter + serviceUpdateSendCount metric.Int64Counter + proxyHeartbeatCount metric.Int64Counter +} + +func newMetrics(meter metric.Meter) (*metrics, error) { + proxyConnectionCount, err := meter.Int64UpDownCounter( + "management_proxy_connection_count", + metric.WithDescription("Number of active proxy connections"), + metric.WithUnit("{connection}"), + ) + if err != nil { + return nil, err + } + + serviceUpdateSendCount, err := meter.Int64Counter( + "management_proxy_service_update_send_count", + metric.WithDescription("Total number of service updates sent to proxies"), + metric.WithUnit("{update}"), + ) + if err != nil { + return nil, err + } + + proxyHeartbeatCount, err := meter.Int64Counter( + "management_proxy_heartbeat_count", + metric.WithDescription("Total number of proxy heartbeats received"), + metric.WithUnit("{heartbeat}"), + ) + if err != nil { + return nil, err + } + + return &metrics{ + proxyConnectionCount: proxyConnectionCount, + serviceUpdateSendCount: serviceUpdateSendCount, + proxyHeartbeatCount: proxyHeartbeatCount, + }, nil +} + +func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) { + m.proxyConnectionCount.Add(context.Background(), 1, + metric.WithAttributes( + attribute.String("cluster", clusterAddr), + )) +} + +func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) { + m.proxyConnectionCount.Add(context.Background(), -1, + metric.WithAttributes( + attribute.String("cluster", clusterAddr), + )) +} + +func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) { + m.serviceUpdateSendCount.Add(context.Background(), 1, + metric.WithAttributes( + attribute.String("cluster", clusterAddr), + )) +} + +func (m *metrics) IncrementProxyHeartbeatCount() { + m.proxyHeartbeatCount.Add(context.Background(), 1) +} diff --git a/management/internals/modules/reverseproxy/proxy/manager_mock.go b/management/internals/modules/reverseproxy/proxy/manager_mock.go new file mode 100644 index 000000000..d9645ba88 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/manager_mock.go @@ -0,0 +1,199 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./manager.go + +// Package proxy is a generated GoMock package. +package proxy + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + proto "github.com/netbirdio/netbird/shared/management/proto" +) + +// MockManager is a mock of Manager interface. +type MockManager struct { + ctrl *gomock.Controller + recorder *MockManagerMockRecorder +} + +// MockManagerMockRecorder is the mock recorder for MockManager. +type MockManagerMockRecorder struct { + mock *MockManager +} + +// NewMockManager creates a new mock instance. +func NewMockManager(ctrl *gomock.Controller) *MockManager { + mock := &MockManager{ctrl: ctrl} + mock.recorder = &MockManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockManager) EXPECT() *MockManagerMockRecorder { + return m.recorder +} + +// CleanupStale mocks base method. +func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupStale indicates an expected call of CleanupStale. +func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration) +} + +// Connect mocks base method. +func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress) + ret0, _ := ret[0].(error) + return ret0 +} + +// Connect indicates an expected call of Connect. +func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress) +} + +// Disconnect mocks base method. +func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Disconnect indicates an expected call of Disconnect. +func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID) +} + +// GetActiveClusterAddresses mocks base method. +func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses. +func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx) +} + +// Heartbeat mocks base method. +func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Heartbeat indicates an expected call of Heartbeat. +func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID) +} + +// MockController is a mock of Controller interface. +type MockController struct { + ctrl *gomock.Controller + recorder *MockControllerMockRecorder +} + +// MockControllerMockRecorder is the mock recorder for MockController. +type MockControllerMockRecorder struct { + mock *MockController +} + +// NewMockController creates a new mock instance. +func NewMockController(ctrl *gomock.Controller) *MockController { + mock := &MockController{ctrl: ctrl} + mock.recorder = &MockControllerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockController) EXPECT() *MockControllerMockRecorder { + return m.recorder +} + +// GetOIDCValidationConfig mocks base method. +func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOIDCValidationConfig") + ret0, _ := ret[0].(OIDCValidationConfig) + return ret0 +} + +// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig. +func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig)) +} + +// GetProxiesForCluster mocks base method. +func (m *MockController) GetProxiesForCluster(clusterAddr string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr) + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetProxiesForCluster indicates an expected call of GetProxiesForCluster. +func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr) +} + +// RegisterProxyToCluster mocks base method. +func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster. +func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID) +} + +// SendServiceUpdateToCluster mocks base method. +func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr) +} + +// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster. +func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr) +} + +// UnregisterProxyFromCluster mocks base method. +func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster. +func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID) +} diff --git a/management/internals/modules/reverseproxy/proxy/proxy.go b/management/internals/modules/reverseproxy/proxy/proxy.go new file mode 100644 index 000000000..699e1ed02 --- /dev/null +++ b/management/internals/modules/reverseproxy/proxy/proxy.go @@ -0,0 +1,20 @@ +package proxy + +import "time" + +// Proxy represents a reverse proxy instance +type Proxy struct { + ID string `gorm:"primaryKey;type:varchar(255)"` + ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"` + IPAddress string `gorm:"type:varchar(45)"` + LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"` + ConnectedAt *time.Time + DisconnectedAt *time.Time + Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"` + CreatedAt time.Time + UpdatedAt time.Time +} + +func (Proxy) TableName() string { + return "proxies" +} diff --git a/management/internals/modules/reverseproxy/interface.go b/management/internals/modules/reverseproxy/service/interface.go similarity index 88% rename from management/internals/modules/reverseproxy/interface.go rename to management/internals/modules/reverseproxy/service/interface.go index e7a21a24c..b420f22a8 100644 --- a/management/internals/modules/reverseproxy/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -1,6 +1,6 @@ -package reverseproxy +package service -//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod +//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod import ( "context" @@ -14,7 +14,7 @@ type Manager interface { DeleteService(ctx context.Context, accountID, userID, serviceID string) error DeleteAllServices(ctx context.Context, accountID, userID string) error SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error - SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error + SetStatus(ctx context.Context, accountID, serviceID string, status Status) error ReloadAllServicesForAccount(ctx context.Context, accountID string) error ReloadService(ctx context.Context, accountID, serviceID string) error GetGlobalServices(ctx context.Context) ([]*Service, error) diff --git a/management/internals/modules/reverseproxy/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go similarity index 99% rename from management/internals/modules/reverseproxy/interface_mock.go rename to management/internals/modules/reverseproxy/service/interface_mock.go index 893025195..727b2c7de 100644 --- a/management/internals/modules/reverseproxy/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -1,8 +1,8 @@ // Code generated by MockGen. DO NOT EDIT. // Source: ./interface.go -// Package reverseproxy is a generated GoMock package. -package reverseproxy +// Package service is a generated GoMock package. +package service import ( context "context" @@ -239,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic } // SetStatus mocks base method. -func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error { +func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status) ret0, _ := ret[0].(error) diff --git a/management/internals/modules/reverseproxy/manager/api.go b/management/internals/modules/reverseproxy/service/manager/api.go similarity index 93% rename from management/internals/modules/reverseproxy/manager/api.go rename to management/internals/modules/reverseproxy/service/manager/api.go index 9117ecd38..70b09e603 100644 --- a/management/internals/modules/reverseproxy/manager/api.go +++ b/management/internals/modules/reverseproxy/service/manager/api.go @@ -6,10 +6,10 @@ import ( "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -17,11 +17,11 @@ import ( ) type handler struct { - manager reverseproxy.Manager + manager rpservice.Manager } // RegisterEndpoints registers all service HTTP endpoints. -func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { h := &handler{ manager: manager, } @@ -72,7 +72,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) { return } - service := new(reverseproxy.Service) + service := new(rpservice.Service) service.FromAPIRequest(&req, userAuth.AccountId) if err = service.Validate(); err != nil { @@ -130,7 +130,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { return } - service := new(reverseproxy.Service) + service := new(rpservice.Service) service.ID = serviceID service.FromAPIRequest(&req, userAuth.AccountId) diff --git a/management/internals/modules/reverseproxy/manager/expose_tracker.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker.go similarity index 99% rename from management/internals/modules/reverseproxy/manager/expose_tracker.go rename to management/internals/modules/reverseproxy/service/manager/expose_tracker.go index ef285e923..11e1f0110 100644 --- a/management/internals/modules/reverseproxy/manager/expose_tracker.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker.go @@ -27,7 +27,7 @@ type trackedExpose struct { type exposeTracker struct { activeExposes sync.Map exposeCreateMu sync.Mutex - manager *managerImpl + manager *Manager } func exposeKey(peerID, domain string) string { diff --git a/management/internals/modules/reverseproxy/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go similarity index 97% rename from management/internals/modules/reverseproxy/manager/expose_tracker_test.go rename to management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go index 2dc726590..154239fb1 100644 --- a/management/internals/modules/reverseproxy/manager/expose_tracker_test.go +++ b/management/internals/modules/reverseproxy/service/manager/expose_tracker_test.go @@ -9,7 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" ) func TestExposeKey(t *testing.T) { @@ -120,7 +120,7 @@ func TestReapExpiredExposes(t *testing.T) { tracker := mgr.exposeTracker ctx := context.Background() - resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) @@ -156,7 +156,7 @@ func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) { tracker := mgr.exposeTracker ctx := context.Background() - resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) @@ -191,7 +191,7 @@ func TestConcurrentTrackAndCount(t *testing.T) { ctx := context.Background() for i := range 5 { - _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080 + i, Protocol: "http", }) diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go similarity index 65% rename from management/internals/modules/reverseproxy/manager/manager.go rename to management/internals/modules/reverseproxy/service/manager/manager.go index 3c02e117b..16a57abb6 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -11,17 +11,15 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" - nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -33,24 +31,22 @@ type ClusterDeriver interface { GetClusterDomains() []string } -type managerImpl struct { +type Manager struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager - settingsManager settings.Manager - proxyGRPCServer *nbgrpc.ProxyServiceServer + proxyController proxy.Controller clusterDeriver ClusterDeriver exposeTracker *exposeTracker } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { - mgr := &managerImpl{ +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager { + mgr := &Manager{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, - settingsManager: settingsManager, - proxyGRPCServer: proxyGRPCServer, + proxyController: proxyController, clusterDeriver: clusterDeriver, } mgr.exposeTracker = &exposeTracker{manager: mgr} @@ -58,11 +54,11 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa } // StartExposeReaper delegates to the expose tracker. -func (m *managerImpl) StartExposeReaper(ctx context.Context) { +func (m *Manager) StartExposeReaper(ctx context.Context) { m.exposeTracker.StartExposeReaper(ctx) } -func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { +func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -86,34 +82,34 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri return services, nil } -func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error { - for _, target := range service.Targets { +func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error { + for _, target := range s.Targets { switch target.TargetType { - case reverseproxy.TargetTypePeer: + case service.TargetTypePeer: peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) if err != nil { - log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err) + log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err) target.Host = unknownHostPlaceholder continue } target.Host = peer.IP.String() - case reverseproxy.TargetTypeHost: + case service.TargetTypeHost: resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) if err != nil { - log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) + log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err) target.Host = unknownHostPlaceholder continue } target.Host = resource.Prefix.Addr().String() - case reverseproxy.TargetTypeDomain: + case service.TargetTypeDomain: resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) if err != nil { - log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) + log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err) target.Host = unknownHostPlaceholder continue } target.Host = resource.Domain - case reverseproxy.TargetTypeSubnet: + case service.TargetTypeSubnet: // For subnets we do not do any lookups on the resource default: return fmt.Errorf("unknown target type: %s", target.TargetType) @@ -122,7 +118,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, return nil } -func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { +func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -143,7 +139,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service return service, nil } -func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -152,29 +148,29 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin return nil, status.NewPermissionDeniedError() } - if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { + if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil { return nil, err } - if err := m.persistNewService(ctx, accountID, service); err != nil { + if err := m.persistNewService(ctx, accountID, s); err != nil { return nil, err } - m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta()) + m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta()) - err = m.replaceHostByLookup(ctx, accountID, service) + err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) - return service, nil + return s, nil } -func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error { +func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error { if m.clusterDeriver != nil { proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) if err != nil { @@ -201,7 +197,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID return nil } -func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error { +func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { return err @@ -219,7 +215,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s }) } -func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { +func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { @@ -235,7 +231,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor return nil } -func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -259,7 +255,7 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.sendServiceUpdateNotifications(service, updateInfo) + m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo) m.accountManager.UpdateAccountPeers(ctx, accountID) return service, nil @@ -271,7 +267,7 @@ type serviceUpdateInfo struct { serviceEnabledChanged bool } -func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) { +func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) { var updateInfo serviceUpdateInfo err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -309,7 +305,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string return &updateInfo, err } -func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error { +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { return err } @@ -326,7 +322,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store. return nil } -func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) { +func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && service.Auth.PasswordAuth.Password == "" { @@ -340,54 +336,40 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve } } -func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) { +func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) { service.Meta = existingService.Meta service.SessionPrivateKey = existingService.SessionPrivateKey service.SessionPublicKey = existingService.SessionPublicKey } -func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { +func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) { + oidcCfg := m.proxyController.GetOIDCValidationConfig() + switch { - case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: - m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "") - m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") - case !service.Enabled && updateInfo.serviceEnabledChanged: - m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") - case service.Enabled && updateInfo.serviceEnabledChanged: - m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") + case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) + case !s.Enabled && updateInfo.serviceEnabledChanged: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster) + case s.Enabled && updateInfo.serviceEnabledChanged: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) default: - m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "") + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster) } } -func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) { - oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() - mapping := service.ToProtoMapping(operation, oldService, oidcCfg) - m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster) -} - -func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) { - if len(mappings) == 0 { - return - } - update := &proto.GetMappingUpdateResponse{ - Mapping: mappings, - } - m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster) -} - // validateTargetReferences checks that all target IDs reference existing peers or resources in the account. -func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error { +func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error { for _, target := range targets { switch target.TargetType { - case reverseproxy.TargetTypePeer: + case service.TargetTypePeer: if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) } return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) } - case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain: + case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) @@ -399,7 +381,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } -func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { +func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -408,9 +390,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv return status.NewPermissionDeniedError() } - var service *reverseproxy.Service + var s *service.Service err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + var err error + s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { return err } @@ -429,20 +412,20 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv return err } - if service.Source == reverseproxy.SourceEphemeral { - m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain) + if s.Source == service.SourceEphemeral { + m.exposeTracker.UntrackExpose(s.SourcePeer, s.Domain) } - m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) + m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta()) - m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) return nil } -func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error { +func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -451,16 +434,16 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s return status.NewPermissionDeniedError() } - var services []*reverseproxy.Service + var services []*service.Service err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var err error - services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID) + services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return err } - for _, service := range services { - if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil { + for _, svc := range services { + if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil { return fmt.Errorf("failed to delete service: %w", err) } } @@ -471,20 +454,14 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s return err } - clusterMappings := make(map[string][]*proto.ProxyMapping) - oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() + oidcCfg := m.proxyController.GetOIDCValidationConfig() - for _, service := range services { - if service.Source == reverseproxy.SourceEphemeral { - m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain) + for _, svc := range services { + if svc.Source == service.SourceEphemeral { + m.exposeTracker.UntrackExpose(svc.SourcePeer, svc.Domain) } - m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta()) - mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg) - clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping) - } - - for cluster, mappings := range clusterMappings { - m.sendMappingsToCluster(mappings, cluster) + m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta()) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster) } m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -494,7 +471,7 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s // SetCertificateIssuedAt sets the certificate issued timestamp to the current time. // Call this when receiving a gRPC notification that the certificate was issued. -func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { +func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { @@ -513,7 +490,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser } // SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) -func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { +func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { @@ -530,50 +507,42 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string }) } -func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error { - service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) +func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error { + s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return fmt.Errorf("failed to get service: %w", err) } - err = m.replaceHostByLookup(ctx, accountID, service) + err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { - return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "") + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) return nil } -func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { +func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) if err != nil { return fmt.Errorf("failed to get services: %w", err) } - clusterMappings := make(map[string][]*proto.ProxyMapping) - oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() - - for _, service := range services { - err = m.replaceHostByLookup(ctx, accountID, service) + for _, s := range services { + err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { - return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg) - clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping) - } - - for cluster, mappings := range clusterMappings { - m.sendMappingsToCluster(mappings, cluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) } return nil } -func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { services, err := m.store.GetServices(ctx, store.LockingStrengthNone) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) @@ -589,7 +558,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se return services, nil } -func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { +func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) { service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return nil, fmt.Errorf("failed to get service: %w", err) @@ -603,7 +572,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s return service, nil } -func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) @@ -619,7 +588,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) return services, nil } -func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { +func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { @@ -637,7 +606,7 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri // validateExposePermission checks whether the peer is allowed to use the expose feature. // It verifies the account has peer expose enabled and that the peer belongs to an allowed group. -func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error { +func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error { settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) @@ -670,7 +639,7 @@ func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, p // CreateServiceFromPeer creates a service initiated by a peer expose request. // It validates the request, checks expose permissions, enforces the per-peer limit, // creates the service, and tracks it for TTL-based reaping. -func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { +func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { if err := req.Validate(); err != nil { return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err) } @@ -679,31 +648,31 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer return nil, err } - serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix) + serviceName, err := service.GenerateExposeName(req.NamePrefix) if err != nil { return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err) } - service := req.ToService(accountID, peerID, serviceName) - service.Source = reverseproxy.SourceEphemeral + svc := req.ToService(accountID, peerID, serviceName) + svc.Source = service.SourceEphemeral - if service.Domain == "" { - domain, err := m.buildRandomDomain(service.Name) + if svc.Domain == "" { + domain, err := m.buildRandomDomain(svc.Name) if err != nil { - return nil, fmt.Errorf("build random domain for service %s: %w", service.Name, err) + return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err) } - service.Domain = domain + svc.Domain = domain } - if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { - groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups) + if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled { + groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups) if err != nil { - return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err) + return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err) } - service.Auth.BearerAuth.DistributionGroups = groupIDs + svc.Auth.BearerAuth.DistributionGroups = groupIDs } - if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { + if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil { return nil, err } @@ -713,45 +682,45 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer } now := time.Now() - service.Meta.LastRenewedAt = &now - service.SourcePeer = peerID + svc.Meta.LastRenewedAt = &now + svc.SourcePeer = peerID - if err := m.persistNewService(ctx, accountID, service); err != nil { + if err := m.persistNewService(ctx, accountID, svc); err != nil { return nil, err } - alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID) + alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, svc.Domain, accountID) if alreadyTracked { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil { - log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err) + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil { + log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", svc.Domain, err) } return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") } if !allowed { - if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil { - log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err) + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, svc.Domain, false); err != nil { + log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", svc.Domain, err) } return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) } - meta := addPeerInfoToEventMeta(service.EventMeta(), peer) - m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta) + meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) + m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta) - if err := m.replaceHostByLookup(ctx, accountID, service); err != nil { - return nil, fmt.Errorf("replace host by lookup for service %s: %w", service.ID, err) + if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil { + return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err) } - m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) - return &reverseproxy.ExposeServiceResponse{ - ServiceName: service.Name, - ServiceURL: "https://" + service.Domain, - Domain: service.Domain, + return &service.ExposeServiceResponse{ + ServiceName: svc.Name, + ServiceURL: "https://" + svc.Domain, + Domain: svc.Domain, }, nil } -func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) { +func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) { if len(groupNames) == 0 { return []string{}, fmt.Errorf("no group names provided") } @@ -766,7 +735,7 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string return groupIDs, nil } -func (m *managerImpl) buildRandomDomain(name string) (string, error) { +func (m *Manager) buildRandomDomain(name string) (string, error) { if m.clusterDeriver == nil { return "", fmt.Errorf("unable to get random domain") } @@ -781,7 +750,7 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) { // RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session. // Returns an error if the expose is not actively tracked. -func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error { +func (m *Manager) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error { if !m.exposeTracker.RenewTrackedExpose(peerID, domain) { return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) } @@ -789,7 +758,7 @@ func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain } // StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service. -func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { +func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) return err @@ -804,8 +773,8 @@ func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID // deleteServiceFromPeer deletes a peer-initiated service identified by domain. // When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. -func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { - service, err := m.lookupPeerService(ctx, accountID, peerID, domain) +func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { + svc, err := m.lookupPeerService(ctx, accountID, peerID, domain) if err != nil { return err } @@ -814,41 +783,41 @@ func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peer if expired { activityCode = activity.PeerServiceExposeExpired } - return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode) + return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode) } // lookupPeerService finds a peer-initiated service by domain and validates ownership. -func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) { - service, err := m.store.GetServiceByDomain(ctx, accountID, domain) +func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) { + svc, err := m.store.GetServiceByDomain(ctx, accountID, domain) if err != nil { return nil, err } - if service.Source != reverseproxy.SourceEphemeral { + if svc.Source != service.SourceEphemeral { return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose") } - if service.SourcePeer != peerID { + if svc.SourcePeer != peerID { return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer") } - return service, nil + return svc, nil } -func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { - var service *reverseproxy.Service +func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { + var svc *service.Service err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var err error - service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { return err } - if service.Source != reverseproxy.SourceEphemeral { + if svc.Source != service.SourceEphemeral { return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose") } - if service.SourcePeer != peerID { + if svc.SourcePeer != peerID { return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer") } @@ -868,11 +837,11 @@ func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, peer = nil } - meta := addPeerInfoToEventMeta(service.EventMeta(), peer) + meta := addPeerInfoToEventMeta(svc.EventMeta(), peer) m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta) - m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) diff --git a/management/internals/modules/reverseproxy/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go similarity index 83% rename from management/internals/modules/reverseproxy/manager/manager_test.go rename to management/internals/modules/reverseproxy/service/manager/manager_test.go index 8e6b0e876..99409e235 100644 --- a/management/internals/modules/reverseproxy/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -10,21 +10,21 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/integrations/extra_settings" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" - "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/shared/management/status" ) @@ -33,13 +33,13 @@ func TestInitializeServiceForCreate(t *testing.T) { accountID := "test-account" t.Run("successful initialization without cluster deriver", func(t *testing.T) { - mgr := &managerImpl{ + mgr := &Manager{ clusterDeriver: nil, } - service := &reverseproxy.Service{ + service := &rpservice.Service{ Domain: "example.com", - Auth: reverseproxy.AuthConfig{}, + Auth: rpservice.AuthConfig{}, } err := mgr.initializeServiceForCreate(ctx, accountID, service) @@ -53,12 +53,12 @@ func TestInitializeServiceForCreate(t *testing.T) { }) t.Run("verifies session keys are different", func(t *testing.T) { - mgr := &managerImpl{ + mgr := &Manager{ clusterDeriver: nil, } - service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}} - service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}} + service1 := &rpservice.Service{Domain: "test1.com", Auth: rpservice.AuthConfig{}} + service2 := &rpservice.Service{Domain: "test2.com", Auth: rpservice.AuthConfig{}} err1 := mgr.initializeServiceForCreate(ctx, accountID, service1) err2 := mgr.initializeServiceForCreate(ctx, accountID, service2) @@ -100,7 +100,7 @@ func TestCheckDomainAvailable(t *testing.T) { setupMock: func(ms *store.MockStore) { ms.EXPECT(). GetServiceByDomain(ctx, accountID, "exists.com"). - Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil) + Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil) }, expectedError: true, errorType: status.AlreadyExists, @@ -112,7 +112,7 @@ func TestCheckDomainAvailable(t *testing.T) { setupMock: func(ms *store.MockStore) { ms.EXPECT(). GetServiceByDomain(ctx, accountID, "exists.com"). - Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) + Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) }, expectedError: false, }, @@ -123,7 +123,7 @@ func TestCheckDomainAvailable(t *testing.T) { setupMock: func(ms *store.MockStore) { ms.EXPECT(). GetServiceByDomain(ctx, accountID, "exists.com"). - Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil) + Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil) }, expectedError: true, errorType: status.AlreadyExists, @@ -149,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) { mockStore := store.NewMockStore(ctrl) tt.setupMock(mockStore) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) if tt.expectedError { @@ -179,7 +179,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { GetServiceByDomain(ctx, accountID, ""). Return(nil, status.Errorf(status.NotFound, "not found")) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") assert.NoError(t, err) @@ -192,9 +192,9 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT(). GetServiceByDomain(ctx, accountID, "test.com"). - Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil) + Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") assert.Error(t, err) @@ -212,7 +212,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { GetServiceByDomain(ctx, accountID, "nil.com"). Return(nil, nil) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") assert.NoError(t, err) @@ -228,10 +228,10 @@ func TestPersistNewService(t *testing.T) { defer ctrl.Finish() mockStore := store.NewMockStore(ctrl) - service := &reverseproxy.Service{ + service := &rpservice.Service{ ID: "service-123", Domain: "new.com", - Targets: []*reverseproxy.Target{}, + Targets: []*rpservice.Target{}, } // Mock ExecuteInTransaction to execute the function immediately @@ -250,7 +250,7 @@ func TestPersistNewService(t *testing.T) { return fn(txMock) }) - mgr := &managerImpl{store: mockStore} + mgr := &Manager{store: mockStore} err := mgr.persistNewService(ctx, accountID, service) assert.NoError(t, err) @@ -261,10 +261,10 @@ func TestPersistNewService(t *testing.T) { defer ctrl.Finish() mockStore := store.NewMockStore(ctrl) - service := &reverseproxy.Service{ + service := &rpservice.Service{ ID: "service-123", Domain: "existing.com", - Targets: []*reverseproxy.Target{}, + Targets: []*rpservice.Target{}, } mockStore.EXPECT(). @@ -273,12 +273,12 @@ func TestPersistNewService(t *testing.T) { txMock := store.NewMockStore(ctrl) txMock.EXPECT(). GetServiceByDomain(ctx, accountID, "existing.com"). - Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil) + Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil) return fn(txMock) }) - mgr := &managerImpl{store: mockStore} + mgr := &Manager{store: mockStore} err := mgr.persistNewService(ctx, accountID, service) require.Error(t, err) @@ -288,21 +288,21 @@ func TestPersistNewService(t *testing.T) { }) } func TestPreserveExistingAuthSecrets(t *testing.T) { - mgr := &managerImpl{} + mgr := &Manager{} t.Run("preserve password when empty", func(t *testing.T) { - existing := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ + existing := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ Enabled: true, Password: "hashed-password", }, }, } - updated := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ + updated := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ Enabled: true, Password: "", }, @@ -315,18 +315,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) { }) t.Run("preserve pin when empty", func(t *testing.T) { - existing := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PinAuth: &reverseproxy.PINAuthConfig{ + existing := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PinAuth: &rpservice.PINAuthConfig{ Enabled: true, Pin: "hashed-pin", }, }, } - updated := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PinAuth: &reverseproxy.PINAuthConfig{ + updated := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PinAuth: &rpservice.PINAuthConfig{ Enabled: true, Pin: "", }, @@ -339,18 +339,18 @@ func TestPreserveExistingAuthSecrets(t *testing.T) { }) t.Run("do not preserve when password is provided", func(t *testing.T) { - existing := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ + existing := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ Enabled: true, Password: "old-password", }, }, } - updated := &reverseproxy.Service{ - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{ + updated := &rpservice.Service{ + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{ Enabled: true, Password: "new-password", }, @@ -365,10 +365,10 @@ func TestPreserveExistingAuthSecrets(t *testing.T) { } func TestPreserveServiceMetadata(t *testing.T) { - mgr := &managerImpl{} + mgr := &Manager{} - existing := &reverseproxy.Service{ - Meta: reverseproxy.ServiceMeta{ + existing := &rpservice.Service{ + Meta: rpservice.Meta{ CertificateIssuedAt: func() *time.Time { t := time.Now(); return &t }(), Status: "active", }, @@ -376,7 +376,7 @@ func TestPreserveServiceMetadata(t *testing.T) { SessionPublicKey: "public-key", } - updated := &reverseproxy.Service{ + updated := &rpservice.Service{ Domain: "updated.com", } @@ -400,31 +400,32 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { IP: net.ParseIP("100.64.0.1"), } - newEphemeralService := func() *reverseproxy.Service { - return &reverseproxy.Service{ + newEphemeralService := func() *rpservice.Service { + return &rpservice.Service{ ID: serviceID, AccountID: accountID, Name: "test-service", Domain: "test.example.com", - Source: reverseproxy.SourceEphemeral, + Source: rpservice.SourceEphemeral, SourcePeer: ownerPeerID, } } - newPermanentService := func() *reverseproxy.Service { - return &reverseproxy.Service{ + newPermanentService := func() *rpservice.Service { + return &rpservice.Service{ ID: serviceID, AccountID: accountID, Name: "api-service", Domain: "api.example.com", - Source: reverseproxy.SourcePermanent, + Source: rpservice.SourcePermanent, } } newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer { t.Helper() - tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour) - srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) t.Cleanup(srv.Close) return srv } @@ -458,10 +459,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). Return(testPeer, nil) - mgr := &managerImpl{ - store: mockStore, - accountManager: mockAccountMgr, - proxyGRPCServer: newProxyServer(t), + mgr := &Manager{ + store: mockStore, + accountManager: mockAccountMgr, + proxyController: func() proxy.Controller { + c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + return c + }(), } err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) @@ -485,7 +490,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { return fn(txMock) }) - mgr := &managerImpl{ + mgr := &Manager{ store: mockStore, } @@ -514,7 +519,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { return fn(txMock) }) - mgr := &managerImpl{ + mgr := &Manager{ store: mockStore, } @@ -556,10 +561,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). Return(testPeer, nil) - mgr := &managerImpl{ - store: mockStore, - accountManager: mockAccountMgr, - proxyGRPCServer: newProxyServer(t), + mgr := &Manager{ + store: mockStore, + accountManager: mockAccountMgr, + proxyController: func() proxy.Controller { + c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + return c + }(), } err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceExposeExpired) @@ -596,10 +605,14 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { GetPeerByID(ctx, store.LockingStrengthNone, accountID, ownerPeerID). Return(testPeer, nil) - mgr := &managerImpl{ - store: mockStore, - accountManager: mockAccountMgr, - proxyGRPCServer: newProxyServer(t), + mgr := &Manager{ + store: mockStore, + accountManager: mockAccountMgr, + proxyController: func() proxy.Controller { + c, err := proxymanager.NewGRPCController(newProxyServer(t), noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + return c + }(), } err := mgr.deletePeerService(ctx, accountID, ownerPeerID, serviceID, activity.PeerServiceUnexposed) @@ -612,19 +625,6 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { }) } -// noopExtraSettings is a minimal extra_settings.Manager for tests without external integrations. -type noopExtraSettings struct{} - -func (n *noopExtraSettings) GetExtraSettings(_ context.Context, _ string) (*types.ExtraSettings, error) { - return &types.ExtraSettings{}, nil -} - -func (n *noopExtraSettings) UpdateExtraSettings(_ context.Context, _, _ string, _ *types.ExtraSettings) (bool, error) { - return false, nil -} - -var _ extra_settings.Manager = (*noopExtraSettings)(nil) - // testClusterDeriver is a minimal ClusterDeriver that returns a fixed domain list. type testClusterDeriver struct { domains []string @@ -646,7 +646,7 @@ const ( ) // setupIntegrationTest creates a real SQLite store with seeded test data for integration tests. -func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) { +func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { t.Helper() ctx := context.Background() @@ -694,30 +694,28 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) { require.NoError(t, err) permsMgr := permissions.NewManager(testStore) - usersMgr := users.NewManager(testStore) - settingsMgr := settings.NewManager(testStore, usersMgr, &noopExtraSettings{}, permsMgr, settings.IdpConfig{}) - var storedEvents []activity.Activity accountMgr := &mock_server.MockAccountManager{ - StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) { - storedEvents = append(storedEvents, activityID.(activity.Activity)) - }, + StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {}, UpdateAccountPeersFunc: func(_ context.Context, _ string) {}, GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) { return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID) }, } - tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Hour) - proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) t.Cleanup(proxySrv.Close) - mgr := &managerImpl{ + proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + + mgr := &Manager{ store: testStore, accountManager: accountMgr, permissionsManager: permsMgr, - settingsManager: settingsMgr, - proxyGRPCServer: proxySrv, + proxyController: proxyController, clusterDeriver: &testClusterDeriver{ domains: []string{"test.netbird.io"}, }, @@ -791,7 +789,7 @@ func Test_validateExposePermission(t *testing.T) { ctrl := gomock.NewController(t) mockStore := store.NewMockStore(ctrl) mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error")) - mgr := &managerImpl{store: mockStore} + mgr := &Manager{store: mockStore} err := mgr.validateExposePermission(ctx, testAccountID, testPeerID) require.Error(t, err) assert.Contains(t, err.Error(), "get account settings") @@ -804,7 +802,7 @@ func TestCreateServiceFromPeer(t *testing.T) { t.Run("creates service with random domain", func(t *testing.T) { mgr, testStore := setupIntegrationTest(t) - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", } @@ -819,7 +817,7 @@ func TestCreateServiceFromPeer(t *testing.T) { persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) require.NoError(t, err) assert.Equal(t, resp.Domain, persisted.Domain) - assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral") + assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral") assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set") assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set") }) @@ -827,7 +825,7 @@ func TestCreateServiceFromPeer(t *testing.T) { t.Run("creates service with custom domain", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 80, Protocol: "http", Domain: "example.com", @@ -848,7 +846,7 @@ func TestCreateServiceFromPeer(t *testing.T) { err = testStore.SaveAccountSettings(ctx, testAccountID, s) require.NoError(t, err) - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", } @@ -861,7 +859,7 @@ func TestCreateServiceFromPeer(t *testing.T) { t.Run("validates request fields", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 0, Protocol: "http", } @@ -875,67 +873,67 @@ func TestCreateServiceFromPeer(t *testing.T) { func TestExposeServiceRequestValidate(t *testing.T) { tests := []struct { name string - req reverseproxy.ExposeServiceRequest + req rpservice.ExposeServiceRequest wantErr string }{ { name: "valid http request", - req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"}, wantErr: "", }, { name: "valid https request with pin", - req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, + req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, wantErr: "", }, { name: "port zero rejected", - req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"}, wantErr: "port must be between 1 and 65535", }, { name: "negative port rejected", - req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"}, wantErr: "port must be between 1 and 65535", }, { name: "port above 65535 rejected", - req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"}, + req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"}, wantErr: "port must be between 1 and 65535", }, { name: "unsupported protocol", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, wantErr: "unsupported protocol", }, { name: "invalid pin format", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, wantErr: "invalid pin", }, { name: "pin too short", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, wantErr: "invalid pin", }, { name: "valid 6-digit pin", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, wantErr: "", }, { name: "empty user group name", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, wantErr: "user group name cannot be empty", }, { name: "invalid name prefix", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, wantErr: "invalid name prefix", }, { name: "valid name prefix", - req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, + req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, wantErr: "", }, } @@ -953,7 +951,7 @@ func TestExposeServiceRequestValidate(t *testing.T) { } t.Run("nil receiver", func(t *testing.T) { - var req *reverseproxy.ExposeServiceRequest + var req *rpservice.ExposeServiceRequest err := req.Validate() require.Error(t, err) assert.Contains(t, err.Error(), "request cannot be nil") @@ -967,7 +965,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { mgr, testStore := setupIntegrationTest(t) // First create a service - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", } @@ -986,7 +984,7 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { t.Run("expire uses correct activity", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", } @@ -1004,7 +1002,7 @@ func TestStopServiceFromPeer(t *testing.T) { t.Run("stops service by domain", func(t *testing.T) { mgr, testStore := setupIntegrationTest(t) - req := &reverseproxy.ExposeServiceRequest{ + req := &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", } @@ -1023,7 +1021,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) { ctx := context.Background() mgr, _ := setupIntegrationTest(t) - resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) @@ -1041,7 +1039,7 @@ func TestDeleteService_UntracksEphemeralExpose(t *testing.T) { assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete") // A new expose should succeed (not blocked by stale tracking) - _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 9090, Protocol: "http", }) @@ -1053,7 +1051,7 @@ func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) { mgr, _ := setupIntegrationTest(t) for i := range 3 { - _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080 + i, Protocol: "http", }) @@ -1074,7 +1072,7 @@ func TestRenewServiceFromPeer(t *testing.T) { t.Run("renews tracked expose", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{ Port: 8080, Protocol: "http", }) @@ -1129,25 +1127,32 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockPerms := permissions.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl) - mockGRPC := &nbgrpc.ProxyServiceServer{} - mgr := &managerImpl{ + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) + require.NoError(t, err) + proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) + t.Cleanup(proxySrv.Close) + + proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) + require.NoError(t, err) + + mgr := &Manager{ store: sqlStore, permissionsManager: mockPerms, accountManager: mockAcct, - proxyGRPCServer: mockGRPC, + proxyController: proxyController, } - service := &reverseproxy.Service{ + service := &rpservice.Service{ ID: "service-1", AccountID: accountID, Domain: "test.example.com", ProxyCluster: "cluster1", Enabled: true, - Targets: []*reverseproxy.Target{ - {AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-1"}, - {AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-2"}, - {AccountID: accountID, ServiceID: "service-1", TargetType: reverseproxy.TargetTypePeer, TargetId: "peer-3"}, + Targets: []*rpservice.Target{ + {AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-1"}, + {AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-2"}, + {AccountID: accountID, ServiceID: "service-1", TargetType: rpservice.TargetTypePeer, TargetId: "peer-3"}, }, } diff --git a/management/internals/modules/reverseproxy/reverseproxy.go b/management/internals/modules/reverseproxy/service/service.go similarity index 94% rename from management/internals/modules/reverseproxy/reverseproxy.go rename to management/internals/modules/reverseproxy/service/service.go index 10226710b..46ae185d6 100644 --- a/management/internals/modules/reverseproxy/reverseproxy.go +++ b/management/internals/modules/reverseproxy/service/service.go @@ -1,4 +1,4 @@ -package reverseproxy +package service import ( "crypto/rand" @@ -14,6 +14,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/util/crypt" @@ -29,15 +30,15 @@ const ( Delete Operation = "delete" ) -type ProxyStatus string +type Status string const ( - StatusPending ProxyStatus = "pending" - StatusActive ProxyStatus = "active" - StatusTunnelNotCreated ProxyStatus = "tunnel_not_created" - StatusCertificatePending ProxyStatus = "certificate_pending" - StatusCertificateFailed ProxyStatus = "certificate_failed" - StatusError ProxyStatus = "error" + StatusPending Status = "pending" + StatusActive Status = "active" + StatusTunnelNotCreated Status = "tunnel_not_created" + StatusCertificatePending Status = "certificate_pending" + StatusCertificateFailed Status = "certificate_failed" + StatusError Status = "error" TargetTypePeer = "peer" TargetTypeHost = "host" @@ -111,14 +112,7 @@ func (a *AuthConfig) ClearSecrets() { } } -type OIDCValidationConfig struct { - Issuer string - Audiences []string - KeysLocation string - MaxTokenAgeSeconds int64 -} - -type ServiceMeta struct { +type Meta struct { CreatedAt time.Time CertificateIssuedAt *time.Time Status string @@ -135,11 +129,11 @@ type Service struct { Enabled bool PassHostHeader bool RewriteRedirects bool - Auth AuthConfig `gorm:"serializer:json"` - Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"` - SessionPrivateKey string `gorm:"column:session_private_key"` - SessionPublicKey string `gorm:"column:session_public_key"` - Source string `gorm:"default:'permanent'"` + Auth AuthConfig `gorm:"serializer:json"` + Meta Meta `gorm:"embedded;embeddedPrefix:meta_"` + SessionPrivateKey string `gorm:"column:session_private_key"` + SessionPublicKey string `gorm:"column:session_public_key"` + Source string `gorm:"default:'permanent'"` SourcePeer string } @@ -165,7 +159,7 @@ func NewService(accountID, name, domain, proxyCluster string, targets []*Target, // only be called during initial creation, not for updates. func (s *Service) InitNewRecord() { s.ID = xid.New().String() - s.Meta = ServiceMeta{ + s.Meta = Meta{ CreatedAt: time.Now(), Status: string(StatusPending), } @@ -239,7 +233,7 @@ func (s *Service) ToAPIResponse() *api.Service { return resp } -func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping { +func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping { pathMappings := make([]*proto.PathMapping, 0, len(s.Targets)) for _, target := range s.Targets { if !target.Enabled { diff --git a/management/internals/modules/reverseproxy/reverseproxy_test.go b/management/internals/modules/reverseproxy/service/service_test.go similarity index 98% rename from management/internals/modules/reverseproxy/reverseproxy_test.go rename to management/internals/modules/reverseproxy/service/service_test.go index cb75ee61f..8b09ab827 100644 --- a/management/internals/modules/reverseproxy/reverseproxy_test.go +++ b/management/internals/modules/reverseproxy/service/service_test.go @@ -1,4 +1,4 @@ -package reverseproxy +package service import ( "errors" @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/shared/hash/argon2id" "github.com/netbirdio/netbird/shared/management/proto" ) @@ -109,7 +110,7 @@ func TestIsDefaultPort(t *testing.T) { } func TestToProtoMapping_PortInTargetURL(t *testing.T) { - oidcConfig := OIDCValidationConfig{} + oidcConfig := proxy.OIDCValidationConfig{} tests := []struct { name string @@ -202,7 +203,7 @@ func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) { {TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true}, }, } - pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{}) + pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{}) require.Len(t, pm.Path, 1) assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target) } @@ -219,7 +220,7 @@ func TestToProtoMapping_OperationTypes(t *testing.T) { } for _, tt := range tests { t.Run(string(tt.op), func(t *testing.T) { - pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{}) + pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{}) assert.Equal(t, tt.want, pm.Type) }) } diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 45c1b763f..2049f0051 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -94,7 +94,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ReverseProxyManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -134,7 +134,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server { if s.Config.HttpConfig.LetsEncryptDomain != "" { certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain) if err != nil { - log.Fatalf("failed to create certificate manager: %v", err) + log.Fatalf("failed to create certificate service: %v", err) } transportCredentials := credentials.NewTLS(certManager.TLSConfig()) gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials)) @@ -152,10 +152,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server { if err != nil { log.Fatalf("failed to create management server: %v", err) } - reverseProxyMgr := s.ReverseProxyManager() - srv.SetReverseProxyManager(reverseProxyMgr) - if reverseProxyMgr != nil { - reverseProxyMgr.StartExposeReaper(context.Background()) + serviceMgr := s.ServiceManager() + srv.SetReverseProxyManager(serviceMgr) + if serviceMgr != nil { + serviceMgr.StartExposeReaper(context.Background()) } mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) @@ -168,9 +168,10 @@ func (s *BaseServer) GRPCServer() *grpc.Server { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer { - proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager()) + proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) s.AfterInit(func(s *BaseServer) { - proxyService.SetProxyManager(s.ReverseProxyManager()) + proxyService.SetServiceManager(s.ServiceManager()) + proxyService.SetProxyController(s.ServiceProxyController()) }) return proxyService }) @@ -193,7 +194,10 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore { - tokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100) + if err != nil { + log.Fatalf("failed to create proxy token store: %v", err) + } log.Info("One-time token store initialized for proxy authentication") return tokenStore }) diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 4ea86900a..62ed659c0 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -106,6 +108,16 @@ func (s *BaseServer) NetworkMapController() network_map.Controller { }) } +func (s *BaseServer) ServiceProxyController() proxy.Controller { + return Create(s, func() proxy.Controller { + controller, err := proxymanager.NewGRPCController(s.ReverseProxyGRPCServer(), s.Metrics().GetMeter()) + if err != nil { + log.Fatalf("failed to create service proxy controller: %v", err) + } + return controller + }) +} + func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { return Create(s, func() *server.AccountRequestBuffer { return server.NewAccountRequestBuffer(context.Background(), s.Store()) diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index faec5b99c..2383019e2 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -8,9 +8,11 @@ import ( "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/modules/peers" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" - nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" "github.com/netbirdio/netbird/management/internals/modules/zones" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" "github.com/netbirdio/netbird/management/internals/modules/zones/records" @@ -99,11 +101,11 @@ func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) if err != nil { - log.Fatalf("failed to create account manager: %v", err) + log.Fatalf("failed to create account service: %v", err) } s.AfterInit(func(s *BaseServer) { - accountManager.SetServiceManager(s.ReverseProxyManager()) + accountManager.SetServiceManager(s.ServiceManager()) }) return accountManager @@ -114,28 +116,28 @@ func (s *BaseServer) IdpManager() idp.Manager { return Create(s, func() idp.Manager { var idpManager idp.Manager var err error - // Use embedded IdP manager if embedded Dex is configured and enabled. + // Use embedded IdP service if embedded Dex is configured and enabled. // Legacy IdpManager won't be used anymore even if configured. if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { idpManager, err = idp.NewEmbeddedIdPManager(context.Background(), s.Config.EmbeddedIdP, s.Metrics()) if err != nil { - log.Fatalf("failed to create embedded IDP manager: %v", err) + log.Fatalf("failed to create embedded IDP service: %v", err) } return idpManager } - // Fall back to external IdP manager + // Fall back to external IdP service if s.Config.IdpManagerConfig != nil { idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics()) if err != nil { - log.Fatalf("failed to create IDP manager: %v", err) + log.Fatalf("failed to create IDP service: %v", err) } } return idpManager }) } -// OAuthConfigProvider is only relevant when we have an embedded IdP manager. Otherwise must be nil +// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider { if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled { return nil @@ -162,7 +164,7 @@ func (s *BaseServer) GroupsManager() groups.Manager { func (s *BaseServer) ResourcesManager() resources.Manager { return Create(s, func() resources.Manager { - return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ReverseProxyManager()) + return resources.NewManager(s.Store(), s.PermissionsManager(), s.GroupsManager(), s.AccountManager(), s.ServiceManager()) }) } @@ -190,15 +192,25 @@ func (s *BaseServer) RecordsManager() records.Manager { }) } -func (s *BaseServer) ReverseProxyManager() reverseproxy.Manager { - return Create(s, func() reverseproxy.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.SettingsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) +func (s *BaseServer) ServiceManager() service.Manager { + return Create(s, func() service.Manager { + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager()) + }) +} + +func (s *BaseServer) ProxyManager() proxy.Manager { + return Create(s, func() proxy.Manager { + manager, err := proxymanager.NewManager(s.Store(), s.Metrics().GetMeter()) + if err != nil { + log.Fatalf("failed to create proxy manager: %v", err) + } + return manager }) } func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager { return Create(s, func() *manager.Manager { - m := manager.NewManager(s.Store(), s.ReverseProxyGRPCServer(), s.PermissionsManager()) + m := manager.NewManager(s.Store(), s.ProxyManager(), s.PermissionsManager()) return &m }) } diff --git a/management/internals/server/server.go b/management/internals/server/server.go index 3f7f9c4c0..5149c338b 100644 --- a/management/internals/server/server.go +++ b/management/internals/server/server.go @@ -157,7 +157,7 @@ func (s *BaseServer) Start(ctx context.Context) error { // Eagerly create the gRPC server so that all AfterInit hooks are registered // before we iterate them. Lazy creation after the loop would miss hooks - // registered during GRPCServer() construction (e.g., SetProxyManager). + // registered during GRPCServer() construction (e.g., SetServiceManager). s.GRPCServer() for _, fn := range s.afterInit { diff --git a/management/internals/shared/grpc/expose_service.go b/management/internals/shared/grpc/expose_service.go index ef00354af..c444471b0 100644 --- a/management/internals/shared/grpc/expose_service.go +++ b/management/internals/shared/grpc/expose_service.go @@ -10,7 +10,7 @@ import ( "google.golang.org/grpc/status" "github.com/netbirdio/netbird/encryption" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbContext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -39,7 +39,7 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &reverseproxy.ExposeServiceRequest{ + created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &rpservice.ExposeServiceRequest{ NamePrefix: exposeReq.NamePrefix, Port: int(exposeReq.Port), Protocol: exposeProtocolToString(exposeReq.Protocol), @@ -167,14 +167,14 @@ func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key return accountID, peer, nil } -func (s *Server) getReverseProxyManager() reverseproxy.Manager { +func (s *Server) getReverseProxyManager() rpservice.Manager { s.reverseProxyMu.RLock() defer s.reverseProxyMu.RUnlock() return s.reverseProxyManager } // SetReverseProxyManager sets the reverse proxy manager on the server. -func (s *Server) SetReverseProxyManager(mgr reverseproxy.Manager) { +func (s *Server) SetReverseProxyManager(mgr rpservice.Manager) { s.reverseProxyMu.Lock() defer s.reverseProxyMu.Unlock() s.reverseProxyManager = mgr diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index dcc37c639..7999407db 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -1,28 +1,23 @@ package grpc import ( + "context" "crypto/rand" + "crypto/sha256" "crypto/subtle" "encoding/base64" + "encoding/hex" + "encoding/json" "fmt" - "sync" "time" + "github.com/eko/gocache/lib/v4/cache" + "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" + + nbcache "github.com/netbirdio/netbird/management/server/cache" ) -// OneTimeTokenStore manages short-lived, single-use authentication tokens -// for proxy-to-management RPC authentication. Tokens are generated when -// a service is created and must be used exactly once by the proxy -// to authenticate a subsequent RPC call. -type OneTimeTokenStore struct { - tokens map[string]*tokenMetadata - mu sync.RWMutex - cleanup *time.Ticker - cleanupDone chan struct{} -} - -// tokenMetadata stores information about a one-time token type tokenMetadata struct { ServiceID string AccountID string @@ -30,20 +25,24 @@ type tokenMetadata struct { CreatedAt time.Time } -// NewOneTimeTokenStore creates a new token store with automatic cleanup -// of expired tokens. The cleanupInterval determines how often expired -// tokens are removed from memory. -func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { - store := &OneTimeTokenStore{ - tokens: make(map[string]*tokenMetadata), - cleanup: time.NewTicker(cleanupInterval), - cleanupDone: make(chan struct{}), +// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC. +// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. +type OneTimeTokenStore struct { + cache *cache.Cache[string] + ctx context.Context +} + +// NewOneTimeTokenStore creates a token store with automatic backend selection +func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) { + cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) + if err != nil { + return nil, fmt.Errorf("failed to create cache store: %w", err) } - // Start background cleanup goroutine - go store.cleanupExpired() - - return store + return &OneTimeTokenStore{ + cache: cache.New[string](cacheStore), + ctx: ctx, + }, nil } // GenerateToken creates a new cryptographically secure one-time token @@ -52,25 +51,30 @@ func NewOneTimeTokenStore(cleanupInterval time.Duration) *OneTimeTokenStore { // // Returns the generated token string or an error if random generation fails. func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.Duration) (string, error) { - // Generate 32 bytes (256 bits) of cryptographically secure random data randomBytes := make([]byte, 32) if _, err := rand.Read(randomBytes); err != nil { return "", fmt.Errorf("failed to generate random token: %w", err) } - // Encode as URL-safe base64 for easy transmission in gRPC token := base64.URLEncoding.EncodeToString(randomBytes) + hashedToken := hashToken(token) - s.mu.Lock() - defer s.mu.Unlock() - - s.tokens[token] = &tokenMetadata{ + metadata := &tokenMetadata{ ServiceID: serviceID, AccountID: accountID, ExpiresAt: time.Now().Add(ttl), CreatedAt: time.Now(), } + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return "", fmt.Errorf("failed to serialize token metadata: %w", err) + } + + if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil { + return "", fmt.Errorf("failed to store token: %w", err) + } + log.Debugf("Generated one-time token for proxy %s in account %s (expires in %s)", serviceID, accountID, ttl) @@ -88,80 +92,45 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time. // - Account ID doesn't match // - Reverse proxy ID doesn't match func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { - s.mu.Lock() - defer s.mu.Unlock() + hashedToken := hashToken(token) - metadata, exists := s.tokens[token] - if !exists { - log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", - serviceID, accountID) + metadataJSON, err := s.cache.Get(s.ctx, hashedToken) + if err != nil { + log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("invalid token") } - // Check expiration + metadata := &tokenMetadata{} + if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil { + log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err) + return fmt.Errorf("invalid token metadata") + } + if time.Now().After(metadata.ExpiresAt) { - delete(s.tokens, token) - log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", - serviceID, accountID) + log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("token expired") } - // Validate account ID using constant-time comparison (prevents timing attacks) if subtle.ConstantTimeCompare([]byte(metadata.AccountID), []byte(accountID)) != 1 { - log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", - metadata.AccountID, accountID) + log.Warnf("Token validation failed: account ID mismatch (expected: %s, got: %s)", metadata.AccountID, accountID) return fmt.Errorf("account ID mismatch") } - // Validate service ID using constant-time comparison if subtle.ConstantTimeCompare([]byte(metadata.ServiceID), []byte(serviceID)) != 1 { - log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", - metadata.ServiceID, serviceID) + log.Warnf("Token validation failed: service ID mismatch (expected: %s, got: %s)", metadata.ServiceID, serviceID) return fmt.Errorf("service ID mismatch") } - // Delete token immediately to enforce single-use - delete(s.tokens, token) + if err := s.cache.Delete(s.ctx, hashedToken); err != nil { + log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err) + } - log.Infof("Token validated and consumed for proxy %s in account %s", - serviceID, accountID) + log.Infof("Token validated and consumed for proxy %s in account %s", serviceID, accountID) return nil } -// cleanupExpired removes expired tokens in the background to prevent memory leaks -func (s *OneTimeTokenStore) cleanupExpired() { - for { - select { - case <-s.cleanup.C: - s.mu.Lock() - now := time.Now() - removed := 0 - for token, metadata := range s.tokens { - if now.After(metadata.ExpiresAt) { - delete(s.tokens, token) - removed++ - } - } - if removed > 0 { - log.Debugf("Cleaned up %d expired one-time tokens", removed) - } - s.mu.Unlock() - case <-s.cleanupDone: - return - } - } -} - -// Close stops the cleanup goroutine and releases resources -func (s *OneTimeTokenStore) Close() { - s.cleanup.Stop() - close(s.cleanupDone) -} - -// GetTokenCount returns the current number of tokens in the store (for debugging/metrics) -func (s *OneTimeTokenStore) GetTokenCount() int { - s.mu.RLock() - defer s.mu.RUnlock() - return len(s.tokens) +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index e47ea5315..676757c1e 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -24,8 +24,9 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/internals/modules/peers" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -58,14 +59,17 @@ type ProxyServiceServer struct { // Map of connected proxies: proxy_id -> proxy connection connectedProxies sync.Map - // Map of cluster address -> set of proxy IDs - clusterProxies sync.Map - // Manager for access logs accessLogManager accesslogs.Manager // Manager for reverse proxy operations - reverseProxyManager reverseproxy.Manager + serviceManager rpservice.Manager + + // ProxyController for service updates and cluster management + proxyController proxy.Controller + + // Manager for proxy connections + proxyManager proxy.Manager // Manager for peers peersManager peers.Manager @@ -104,7 +108,7 @@ type proxyConnection struct { } // NewProxyServiceServer creates a new proxy service server. -func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer { +func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer { ctx, cancel := context.WithCancel(context.Background()) s := &ProxyServiceServer{ accessLogManager: accessLogMgr, @@ -112,9 +116,11 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT tokenStore: tokenStore, peersManager: peersManager, usersManager: usersManager, + proxyManager: proxyMgr, pkceCleanupCancel: cancel, } go s.cleanupPKCEVerifiers(ctx) + go s.cleanupStaleProxies(ctx) return s } @@ -138,13 +144,33 @@ func (s *ProxyServiceServer) cleanupPKCEVerifiers(ctx context.Context) { } } +// cleanupStaleProxies periodically removes proxies that haven't sent heartbeat in 10 minutes +func (s *ProxyServiceServer) cleanupStaleProxies(ctx context.Context) { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := s.proxyManager.CleanupStale(ctx, 10*time.Minute); err != nil { + log.WithContext(ctx).Debugf("Failed to cleanup stale proxies: %v", err) + } + } + } +} + // Close stops background goroutines. func (s *ProxyServiceServer) Close() { s.pkceCleanupCancel() } -func (s *ProxyServiceServer) SetProxyManager(manager reverseproxy.Manager) { - s.reverseProxyManager = manager +func (s *ProxyServiceServer) SetServiceManager(manager rpservice.Manager) { + s.serviceManager = manager +} + +func (s *ProxyServiceServer) SetProxyController(proxyController proxy.Controller) { + s.proxyController = proxyController } // GetMappingUpdate handles the control stream with proxy clients @@ -179,7 +205,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } s.connectedProxies.Store(proxyID, conn) - s.addToCluster(conn.address, proxyID) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } + + // Register proxy in database + if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in database: %v", proxyID, err) + } + log.WithFields(log.Fields{ "proxy_id": proxyID, "address": proxyAddress, @@ -187,8 +221,15 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest "total_proxies": len(s.GetConnectedProxies()), }).Info("Proxy registered in cluster") defer func() { + if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil { + log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err) + } + s.connectedProxies.Delete(proxyID) - s.removeFromCluster(conn.address, proxyID) + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { + log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) + } + cancel() log.Infof("Proxy %s disconnected", proxyID) }() @@ -200,6 +241,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest errChan := make(chan error, 2) go s.sender(conn, errChan) + // Start heartbeat goroutine + go s.heartbeat(connCtx, proxyID) + select { case err := <-errChan: return fmt.Errorf("send update to proxy %s: %w", proxyID, err) @@ -208,10 +252,27 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } } +// heartbeat updates the proxy's last_seen timestamp every minute +func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID string) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if err := s.proxyManager.Heartbeat(ctx, proxyID); err != nil { + log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err) + } + case <-ctx.Done(): + return + } + } +} + // sendSnapshot sends the initial snapshot of services to the connecting proxy. // Only services matching the proxy's cluster address are sent. func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnection) error { - services, err := s.reverseProxyManager.GetGlobalServices(ctx) + services, err := s.serviceManager.GetGlobalServices(ctx) if err != nil { return fmt.Errorf("get services from store: %w", err) } @@ -220,7 +281,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec return fmt.Errorf("proxy address is invalid") } - var filtered []*reverseproxy.Service + var filtered []*rpservice.Service for _, service := range services { if !service.Enabled { continue @@ -255,7 +316,7 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec if err := conn.stream.Send(&proto.GetMappingUpdateResponse{ Mapping: []*proto.ProxyMapping{ service.ToProtoMapping( - reverseproxy.Create, // Initial snapshot, all records are "new" for the proxy. + rpservice.Create, // Initial snapshot, all records are "new" for the proxy. token, s.GetOIDCValidationConfig(), ), @@ -389,61 +450,47 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string { return urls } -// addToCluster registers a proxy in a cluster. -func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) { - if clusterAddr == "" { - return - } - proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) - proxySet.(*sync.Map).Store(proxyID, struct{}{}) - log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr) -} - -// removeFromCluster removes a proxy from a cluster. -func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { - if clusterAddr == "" { - return - } - if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok { - proxySet.(*sync.Map).Delete(proxyID) - log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr) - } -} - // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. -func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.GetMappingUpdateResponse, clusterAddr string) { +func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) { + updateResponse := &proto.GetMappingUpdateResponse{ + Mapping: []*proto.ProxyMapping{update}, + } + if clusterAddr == "" { - s.SendServiceUpdate(update) + s.SendServiceUpdate(updateResponse) return } - proxySet, ok := s.clusterProxies.Load(clusterAddr) - if !ok { - log.Debugf("No proxies connected for cluster %s", clusterAddr) + if s.proxyController == nil { + log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr) + return + } + + proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr) + if len(proxyIDs) == 0 { + log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr) return } log.Debugf("Sending service update to cluster %s", clusterAddr) - proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { - proxyID := key.(string) + for _, proxyID := range proxyIDs { if connVal, ok := s.connectedProxies.Load(proxyID); ok { conn := connVal.(*proxyConnection) - msg := s.perProxyMessage(update, proxyID) + msg := s.perProxyMessage(updateResponse, proxyID) if msg == nil { - return true + continue } select { case conn.sendChan <- msg: - log.Debugf("Sent service update to proxy %s in cluster %s", proxyID, clusterAddr) + log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: - log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) } } - return true - }) + } } // perProxyMessage returns a copy of update with a fresh one-time token for @@ -490,35 +537,8 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping { } } -// GetAvailableClusters returns information about all connected proxy clusters. -func (s *ProxyServiceServer) GetAvailableClusters() []ClusterInfo { - clusterCounts := make(map[string]int) - s.clusterProxies.Range(func(key, value interface{}) bool { - clusterAddr := key.(string) - proxySet := value.(*sync.Map) - count := 0 - proxySet.Range(func(_, _ interface{}) bool { - count++ - return true - }) - if count > 0 { - clusterCounts[clusterAddr] = count - } - return true - }) - - clusters := make([]ClusterInfo, 0, len(clusterCounts)) - for addr, count := range clusterCounts { - clusters = append(clusters, ClusterInfo{ - Address: addr, - ConnectedProxies: count, - }) - } - return clusters -} - func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.AuthenticateRequest) (*proto.AuthenticateResponse, error) { - service, err := s.reverseProxyManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) + service, err := s.serviceManager.GetServiceByID(ctx, req.GetAccountId(), req.GetId()) if err != nil { log.WithContext(ctx).Debugf("failed to get service from store: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get service from store: %v", err) @@ -537,7 +557,7 @@ func (s *ProxyServiceServer) Authenticate(ctx context.Context, req *proto.Authen }, nil } -func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *reverseproxy.Service) (bool, string, proxyauth.Method) { +func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto.AuthenticateRequest, service *rpservice.Service) (bool, string, proxyauth.Method) { switch v := req.GetRequest().(type) { case *proto.AuthenticateRequest_Pin: return s.authenticatePIN(ctx, req.GetId(), v, service.Auth.PinAuth) @@ -548,7 +568,7 @@ func (s *ProxyServiceServer) authenticateRequest(ctx context.Context, req *proto } } -func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *reverseproxy.PINAuthConfig) (bool, string, proxyauth.Method) { +func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Pin, auth *rpservice.PINAuthConfig) (bool, string, proxyauth.Method) { if auth == nil || !auth.Enabled { log.WithContext(ctx).Debugf("PIN authentication attempted but not enabled for service %s", serviceID) return false, "", "" @@ -562,7 +582,7 @@ func (s *ProxyServiceServer) authenticatePIN(ctx context.Context, serviceID stri return true, "pin-user", proxyauth.MethodPIN } -func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *reverseproxy.PasswordAuthConfig) (bool, string, proxyauth.Method) { +func (s *ProxyServiceServer) authenticatePassword(ctx context.Context, serviceID string, req *proto.AuthenticateRequest_Password, auth *rpservice.PasswordAuthConfig) (bool, string, proxyauth.Method) { if auth == nil || !auth.Enabled { log.WithContext(ctx).Debugf("password authentication attempted but not enabled for service %s", serviceID) return false, "", "" @@ -584,7 +604,7 @@ func (s *ProxyServiceServer) logAuthenticationError(ctx context.Context, err err } } -func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *reverseproxy.Service, userId string, method proxyauth.Method) (string, error) { +func (s *ProxyServiceServer) generateSessionToken(ctx context.Context, authenticated bool, service *rpservice.Service, userId string, method proxyauth.Method) (string, error) { if !authenticated || service.SessionPrivateKey == "" { return "", nil } @@ -624,7 +644,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se } if certificateIssued { - if err := s.reverseProxyManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { + if err := s.serviceManager.SetCertificateIssuedAt(ctx, accountID, serviceID); err != nil { log.WithContext(ctx).WithError(err).Error("failed to set certificate issued timestamp") return nil, status.Errorf(codes.Internal, "update certificate timestamp: %v", err) } @@ -636,7 +656,7 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se internalStatus := protoStatusToInternal(protoStatus) - if err := s.reverseProxyManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { + if err := s.serviceManager.SetStatus(ctx, accountID, serviceID, internalStatus); err != nil { log.WithContext(ctx).WithError(err).Error("failed to update service status") return nil, status.Errorf(codes.Internal, "update service status: %v", err) } @@ -651,22 +671,22 @@ func (s *ProxyServiceServer) SendStatusUpdate(ctx context.Context, req *proto.Se } // protoStatusToInternal maps proto status to internal status -func protoStatusToInternal(protoStatus proto.ProxyStatus) reverseproxy.ProxyStatus { +func protoStatusToInternal(protoStatus proto.ProxyStatus) rpservice.Status { switch protoStatus { case proto.ProxyStatus_PROXY_STATUS_PENDING: - return reverseproxy.StatusPending + return rpservice.StatusPending case proto.ProxyStatus_PROXY_STATUS_ACTIVE: - return reverseproxy.StatusActive + return rpservice.StatusActive case proto.ProxyStatus_PROXY_STATUS_TUNNEL_NOT_CREATED: - return reverseproxy.StatusTunnelNotCreated + return rpservice.StatusTunnelNotCreated case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_PENDING: - return reverseproxy.StatusCertificatePending + return rpservice.StatusCertificatePending case proto.ProxyStatus_PROXY_STATUS_CERTIFICATE_FAILED: - return reverseproxy.StatusCertificateFailed + return rpservice.StatusCertificateFailed case proto.ProxyStatus_PROXY_STATUS_ERROR: - return reverseproxy.StatusError + return rpservice.StatusError default: - return reverseproxy.StatusError + return rpservice.StatusError } } @@ -731,7 +751,7 @@ func (s *ProxyServiceServer) GetOIDCURL(ctx context.Context, req *proto.GetOIDCU return nil, status.Errorf(codes.InvalidArgument, "parse redirect url: %v", err) } // Validate redirectURL against known service endpoints to avoid abuse of OIDC redirection. - services, err := s.reverseProxyManager.GetAccountServices(ctx, req.GetAccountId()) + services, err := s.serviceManager.GetAccountServices(ctx, req.GetAccountId()) if err != nil { log.WithContext(ctx).Errorf("failed to get account services: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "get account services: %v", err) @@ -794,8 +814,8 @@ func (s *ProxyServiceServer) GetOIDCConfig() ProxyOIDCConfig { // GetOIDCValidationConfig returns the OIDC configuration for token validation // in the format needed by ToProtoMapping. -func (s *ProxyServiceServer) GetOIDCValidationConfig() reverseproxy.OIDCValidationConfig { - return reverseproxy.OIDCValidationConfig{ +func (s *ProxyServiceServer) GetOIDCValidationConfig() proxy.OIDCValidationConfig { + return proxy.OIDCValidationConfig{ Issuer: s.oidcConfig.Issuer, Audiences: []string{s.oidcConfig.Audience}, KeysLocation: s.oidcConfig.KeysLocation, @@ -854,12 +874,12 @@ func (s *ProxyServiceServer) ValidateState(state string) (verifier, redirectURL // GenerateSessionToken creates a signed session JWT for the given domain and user. func (s *ProxyServiceServer) GenerateSessionToken(ctx context.Context, domain, userID string, method proxyauth.Method) (string, error) { // Find the service by domain to get its signing key - services, err := s.reverseProxyManager.GetGlobalServices(ctx) + services, err := s.serviceManager.GetGlobalServices(ctx) if err != nil { return "", fmt.Errorf("get services: %w", err) } - var service *reverseproxy.Service + var service *rpservice.Service for _, svc := range services { if svc.Domain == domain { service = svc @@ -925,8 +945,8 @@ func (s *ProxyServiceServer) ValidateUserGroupAccess(ctx context.Context, domain return fmt.Errorf("user %s not in allowed groups for domain %s", user.Id, domain) } -func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { - services, err := s.reverseProxyManager.GetAccountServices(ctx, accountID) +func (s *ProxyServiceServer) getAccountServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) { + services, err := s.serviceManager.GetAccountServices(ctx, accountID) if err != nil { return nil, fmt.Errorf("get account services: %w", err) } @@ -1047,8 +1067,8 @@ func (s *ProxyServiceServer) ValidateSession(ctx context.Context, req *proto.Val }, nil } -func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*reverseproxy.Service, error) { - services, err := s.reverseProxyManager.GetGlobalServices(ctx) +func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain string) (*rpservice.Service, error) { + services, err := s.serviceManager.GetGlobalServices(ctx) if err != nil { return nil, fmt.Errorf("get services: %w", err) } @@ -1062,7 +1082,7 @@ func (s *ProxyServiceServer) getServiceByDomain(ctx context.Context, domain stri return nil, fmt.Errorf("service not found for domain: %s", domain) } -func (s *ProxyServiceServer) checkGroupAccess(service *reverseproxy.Service, user *types.User) error { +func (s *ProxyServiceServer) checkGroupAccess(service *rpservice.Service, user *types.User) error { if service.Auth.BearerAuth == nil || !service.Auth.BearerAuth.Enabled { return nil } diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 827897981..22fe4506b 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -8,12 +8,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/types" ) type mockReverseProxyManager struct { - proxiesByAccount map[string][]*reverseproxy.Service + proxiesByAccount map[string][]*service.Service err error } @@ -21,31 +21,31 @@ func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, account return nil } -func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { if m.err != nil { return nil, m.err } return m.proxiesByAccount[accountID], nil } -func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *mockReverseProxyManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return nil, nil } -func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { - return []*reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { + return []*service.Service{}, nil } -func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) GetService(ctx context.Context, accountID, userID, reverseProxyID string) (*service.Service, error) { + return &service.Service{}, nil } -func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) CreateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) { + return &service.Service{}, nil } -func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *reverseproxy.Service) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) UpdateService(ctx context.Context, accountID, userID string, rp *service.Service) (*service.Service, error) { + return &service.Service{}, nil } func (m *mockReverseProxyManager) DeleteService(ctx context.Context, accountID, userID, reverseProxyID string) error { @@ -56,7 +56,7 @@ func (m *mockReverseProxyManager) SetCertificateIssuedAt(ctx context.Context, ac return nil } -func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status reverseproxy.ProxyStatus) error { +func (m *mockReverseProxyManager) SetStatus(ctx context.Context, accountID, reverseProxyID string, status service.Status) error { return nil } @@ -68,16 +68,16 @@ func (m *mockReverseProxyManager) ReloadService(ctx context.Context, accountID, return nil } -func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil +func (m *mockReverseProxyManager) GetServiceByID(ctx context.Context, accountID, reverseProxyID string) (*service.Service, error) { + return &service.Service{}, nil } func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { return "", nil } -func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { - return &reverseproxy.ExposeServiceResponse{}, nil +func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + return &service.ExposeServiceResponse{}, nil } func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { @@ -111,7 +111,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name string domain string userID string - proxiesByAccount map[string][]*reverseproxy.Service + proxiesByAccount map[string][]*service.Service users map[string]*types.User proxyErr error userErr error @@ -122,7 +122,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user not found", domain: "app.example.com", userID: "unknown-user", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{Domain: "app.example.com", AccountID: "account1"}}, }, users: map[string]*types.User{}, @@ -133,7 +133,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "proxy not found in user's account", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{}, + proxiesByAccount: map[string][]*service.Service{}, users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "account1"}, }, @@ -144,7 +144,7 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "proxy exists in different account - not accessible", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account2": {{Domain: "app.example.com", AccountID: "account2"}}, }, users: map[string]*types.User{ @@ -157,8 +157,8 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "no bearer auth configured - same account allows access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ - "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}}, + proxiesByAccount: map[string][]*service.Service{ + "account1": {{Domain: "app.example.com", AccountID: "account1", Auth: service.AuthConfig{}}}, }, users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "account1"}, @@ -169,12 +169,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "bearer auth disabled - same account allows access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: false}, + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{Enabled: false}, }, }}, }, @@ -187,12 +187,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "bearer auth enabled but no groups configured - same account allows access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{}, }, @@ -208,12 +208,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user not in allowed groups", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"group1", "group2"}, }, @@ -230,12 +230,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user in one of the allowed groups - allow access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"group1", "group2"}, }, @@ -251,12 +251,12 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "user in all allowed groups - allow access", domain: "app.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{ Domain: "app.example.com", AccountID: "account1", - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"group1", "group2"}, }, @@ -284,10 +284,10 @@ func TestValidateUserGroupAccess(t *testing.T) { name: "multiple proxies in account - finds correct one", domain: "app2.example.com", userID: "user1", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": { {Domain: "app1.example.com", AccountID: "account1"}, - {Domain: "app2.example.com", AccountID: "account1", Auth: reverseproxy.AuthConfig{}}, + {Domain: "app2.example.com", AccountID: "account1", Auth: service.AuthConfig{}}, {Domain: "app3.example.com", AccountID: "account1"}, }, }, @@ -301,7 +301,7 @@ func TestValidateUserGroupAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := &ProxyServiceServer{ - reverseProxyManager: &mockReverseProxyManager{ + serviceManager: &mockReverseProxyManager{ proxiesByAccount: tt.proxiesByAccount, err: tt.proxyErr, }, @@ -328,7 +328,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name string accountID string domain string - proxiesByAccount map[string][]*reverseproxy.Service + proxiesByAccount map[string][]*service.Service err error expectProxy bool expectErr bool @@ -337,7 +337,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name: "proxy found", accountID: "account1", domain: "app.example.com", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": { {Domain: "other.example.com", AccountID: "account1"}, {Domain: "app.example.com", AccountID: "account1"}, @@ -350,7 +350,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name: "proxy not found in account", accountID: "account1", domain: "unknown.example.com", - proxiesByAccount: map[string][]*reverseproxy.Service{ + proxiesByAccount: map[string][]*service.Service{ "account1": {{Domain: "app.example.com", AccountID: "account1"}}, }, expectProxy: false, @@ -360,7 +360,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { name: "empty proxy list for account", accountID: "account1", domain: "app.example.com", - proxiesByAccount: map[string][]*reverseproxy.Service{}, + proxiesByAccount: map[string][]*service.Service{}, expectProxy: false, expectErr: true, }, @@ -378,7 +378,7 @@ func TestGetAccountProxyByDomain(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { server := &ProxyServiceServer{ - reverseProxyManager: &mockReverseProxyManager{ + serviceManager: &mockReverseProxyManager{ proxiesByAccount: tt.proxiesByAccount, err: tt.err, }, diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index de8ca3c84..ddeadac5a 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -1,19 +1,73 @@ package grpc import ( + "context" "crypto/rand" "encoding/base64" "strings" - "sync" "testing" "time" + "sync" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/shared/management/proto" ) +type testProxyController struct { + mu sync.Mutex + clusterProxies map[string]map[string]struct{} +} + +func newTestProxyController() *testProxyController { + return &testProxyController{ + clusterProxies: make(map[string]map[string]struct{}), + } +} + +func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) { +} + +func (c *testProxyController) GetOIDCValidationConfig() proxy.OIDCValidationConfig { + return proxy.OIDCValidationConfig{} +} + +func (c *testProxyController) RegisterProxyToCluster(_ context.Context, clusterAddr, proxyID string) error { + c.mu.Lock() + defer c.mu.Unlock() + if _, ok := c.clusterProxies[clusterAddr]; !ok { + c.clusterProxies[clusterAddr] = make(map[string]struct{}) + } + c.clusterProxies[clusterAddr][proxyID] = struct{}{} + return nil +} + +func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, clusterAddr, proxyID string) error { + c.mu.Lock() + defer c.mu.Unlock() + if proxies, ok := c.clusterProxies[clusterAddr]; ok { + delete(proxies, proxyID) + } + return nil +} + +func (c *testProxyController) GetProxiesForCluster(clusterAddr string) []string { + c.mu.Lock() + defer c.mu.Unlock() + proxies, ok := c.clusterProxies[clusterAddr] + if !ok { + return nil + } + result := make([]string, 0, len(proxies)) + for id := range proxies { + result = append(result, id) + } + return result +} + // registerFakeProxy adds a fake proxy connection to the server's internal maps // and returns the channel where messages will be received. func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse { @@ -25,8 +79,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan } s.connectedProxies.Store(proxyID, conn) - proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) - proxySet.(*sync.Map).Store(proxyID, struct{}{}) + _ = s.proxyController.RegisterProxyToCluster(context.Background(), clusterAddr, proxyID) return ch } @@ -41,12 +94,13 @@ func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpda } func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) s := &ProxyServiceServer{ tokenStore: tokenStore, } + s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" const numProxies = 3 @@ -67,11 +121,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { }, } - update := &proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{mapping}, - } - - s.SendServiceUpdateToCluster(update, cluster) + s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) tokens := make([]string, numProxies) for i, ch := range channels { @@ -101,12 +151,13 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { } func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) s := &ProxyServiceServer{ tokenStore: tokenStore, } + s.SetProxyController(newTestProxyController()) const cluster = "proxy.example.com" ch1 := registerFakeProxy(s, "proxy-a", cluster) @@ -119,11 +170,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { Domain: "test.example.com", } - update := &proto.GetMappingUpdateResponse{ - Mapping: []*proto.ProxyMapping{mapping}, - } - - s.SendServiceUpdateToCluster(update, cluster) + s.SendServiceUpdateToCluster(context.Background(), mapping, cluster) resp1 := drainChannel(ch1) resp2 := drainChannel(ch2) @@ -135,18 +182,16 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { // Delete operations should not generate tokens assert.Empty(t, resp1.Mapping[0].AuthToken) assert.Empty(t, resp2.Mapping[0].AuthToken) - - // No tokens should have been created - assert.Equal(t, 0, tokenStore.GetTokenCount()) } func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { - tokenStore := NewOneTimeTokenStore(time.Hour) - defer tokenStore.Close() + tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) + require.NoError(t, err) s := &ProxyServiceServer{ tokenStore: tokenStore, } + s.SetProxyController(newTestProxyController()) // Register proxies in different clusters (SendServiceUpdate broadcasts to all) ch1 := registerFakeProxy(s, "proxy-a", "cluster-a") diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 029d71e2e..a07cafe90 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -26,7 +26,7 @@ import ( "github.com/netbirdio/netbird/shared/management/client/common" "github.com/netbirdio/netbird/management/internals/controllers/network_map" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/job" @@ -82,7 +82,7 @@ type Server struct { syncLimEnabled bool syncLim int32 - reverseProxyManager reverseproxy.Manager + reverseProxyManager rpservice.Manager reverseProxyMu sync.RWMutex } diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 640a27bb2..124ddf620 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -13,7 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -34,11 +34,15 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { testStore, storeCleanup, err := store.NewTestStoreFromSQL(ctx, "../../../server/testdata/auth_callback.sql", t.TempDir()) require.NoError(t, err) - proxyManager := &testValidateSessionProxyManager{store: testStore} + serviceManager := &testValidateSessionServiceManager{store: testStore} usersManager := &testValidateSessionUsersManager{store: testStore} + proxyManager := &testValidateSessionProxyManager{} - proxyService := NewProxyServiceServer(nil, NewOneTimeTokenStore(time.Minute), ProxyOIDCConfig{}, nil, usersManager) - proxyService.SetProxyManager(proxyManager) + tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + + proxyService := NewProxyServiceServer(nil, tokenStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) + proxyService.SetServiceManager(serviceManager) createTestProxies(t, ctx, testStore) @@ -54,7 +58,7 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) pubKey, privKey := generateSessionKeyPair(t) - testProxy := &reverseproxy.Service{ + testProxy := &service.Service{ ID: "testProxyId", AccountID: "testAccountId", Name: "Test Proxy", @@ -62,15 +66,15 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) Enabled: true, SessionPrivateKey: privKey, SessionPublicKey: pubKey, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, }, }, } require.NoError(t, testStore.CreateService(ctx, testProxy)) - restrictedProxy := &reverseproxy.Service{ + restrictedProxy := &service.Service{ ID: "restrictedProxyId", AccountID: "testAccountId", Name: "Restricted Proxy", @@ -78,8 +82,8 @@ func createTestProxies(t *testing.T, ctx context.Context, testStore store.Store) Enabled: true, SessionPrivateKey: privKey, SessionPublicKey: pubKey, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"allowedGroupId"}, }, @@ -239,79 +243,101 @@ func TestValidateSession_MissingToken(t *testing.T) { assert.Contains(t, resp.DeniedReason, "missing") } -type testValidateSessionProxyManager struct { +type testValidateSessionServiceManager struct { store store.Store } -func (m *testValidateSessionProxyManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } -func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionServiceManager) DeleteService(_ context.Context, _, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) DeleteAllServices(_ context.Context, _, _ string) error { +func (m *testValidateSessionServiceManager) DeleteAllServices(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { +func (m *testValidateSessionServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { +func (m *testValidateSessionServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error { return nil } -func (m *testValidateSessionProxyManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { +func (m *testValidateSessionServiceManager) ReloadAllServicesForAccount(_ context.Context, _ string) error { return nil } -func (m *testValidateSessionProxyManager) ReloadService(_ context.Context, _, _ string) error { +func (m *testValidateSessionServiceManager) ReloadService(_ context.Context, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return m.store.GetServices(ctx, store.LockingStrengthNone) } -func (m *testValidateSessionProxyManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) } -func (m *testValidateSessionProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *testValidateSessionServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } -func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { +func (m *testValidateSessionServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ string) (string, error) { return "", nil } -func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { +func (m *testValidateSessionServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { return nil, nil } -func (m *testValidateSessionProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) StartExposeReaper(_ context.Context) {} +func (m *testValidateSessionServiceManager) StartExposeReaper(_ context.Context) {} + +type testValidateSessionProxyManager struct{} + +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error { + return nil +} + +func (m *testValidateSessionProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) { + return nil, nil +} + +func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { + return nil +} type testValidateSessionUsersManager struct { store store.Store diff --git a/management/server/account.go b/management/server/account.go index fb8592164..550971337 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -15,7 +15,7 @@ import ( "sync" "time" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/shared/auth" @@ -83,9 +83,9 @@ type DefaultAccountManager struct { requestBuffer *AccountRequestBuffer - proxyController port_forwarding.Controller - settingsManager settings.Manager - reverseProxyManager reverseproxy.Manager + proxyController port_forwarding.Controller + settingsManager settings.Manager + serviceManager service.Manager // config contains the management server configuration config *nbconfig.Config @@ -115,8 +115,8 @@ type DefaultAccountManager struct { var _ account.Manager = (*DefaultAccountManager)(nil) -func (am *DefaultAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { - am.reverseProxyManager = serviceManager +func (am *DefaultAccountManager) SetServiceManager(serviceManager service.Manager) { + am.serviceManager = serviceManager } func isUniqueConstraintError(err error) bool { @@ -395,7 +395,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta) } if reloadReverseProxy { - if err = am.reverseProxyManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { + if err = am.serviceManager.ReloadAllServicesForAccount(ctx, accountID); err != nil { log.WithContext(ctx).Warnf("failed to reload all services for account %s: %v", accountID, err) } } @@ -730,7 +730,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err) } - err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID) + err = am.serviceManager.DeleteAllServices(ctx, accountID, userID) if err != nil { return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 893e894e1..45af63ae8 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -8,7 +8,7 @@ import ( "net/netip" "time" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/shared/auth" nbdns "github.com/netbirdio/netbird/dns" @@ -142,5 +142,5 @@ type Manager interface { CreatePeerJob(ctx context.Context, accountID, peerID, userID string, job *types.Job) error GetAllPeerJobs(ctx context.Context, accountID, userID, peerID string) ([]*types.Job, error) GetPeerJobByID(ctx context.Context, accountID, userID, peerID, jobID string) (*types.Job, error) - SetServiceManager(serviceManager reverseproxy.Manager) + SetServiceManager(serviceManager service.Manager) } diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index ab6e8b1c9..90700c795 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -13,7 +13,7 @@ import ( gomock "github.com/golang/mock/gomock" dns "github.com/netbirdio/netbird/dns" - reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" activity "github.com/netbirdio/netbird/management/server/activity" idp "github.com/netbirdio/netbird/management/server/idp" peer "github.com/netbirdio/netbird/management/server/peer" @@ -1494,7 +1494,7 @@ func (mr *MockManagerMockRecorder) SaveUser(ctx, accountID, initiatorUserID, upd } // SetServiceManager mocks base method. -func (m *MockManager) SetServiceManager(serviceManager reverseproxy.Manager) { +func (m *MockManager) SetServiceManager(serviceManager service.Manager) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetServiceManager", serviceManager) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 340e130d9..65bab6c18 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -19,6 +19,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" @@ -27,8 +28,10 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" - reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" @@ -1803,12 +1806,12 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, - Services: []*reverseproxy.Service{ + Services: []*service.Service{ { ID: "service1", Name: "test-service", AccountID: "account1", - Targets: []*reverseproxy.Target{}, + Targets: []*service.Target{}, }, }, NetworkMapCache: &types.NetworkMapBuilder{}, @@ -3113,6 +3116,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU permissionsManager := permissions.NewManager(store) peersManager := peers.NewManager(store, permissionsManager) + proxyManager := proxy.NewMockManager(ctrl) + proxyManager.EXPECT(). + CleanupStale(gomock.Any(), gomock.Any()). + Return(nil). + AnyTimes() + ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) @@ -3123,8 +3132,12 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU return nil, nil, err } - proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil) - manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, settingsMockManager, proxyGrpcServer, nil)) + proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager) + proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) + if err != nil { + return nil, nil, err + } + manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyController, nil)) return manager, updateManager, nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index dd6869d50..fa818e532 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -766,7 +766,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("saving group linked to network router", func(t *testing.T) { permissionsManager := permissions.NewManager(manager.Store) groupsManager := groups.NewManager(manager.Store, permissionsManager, manager) - resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.reverseProxyManager) + resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager) routersManager := routers.NewManager(manager.Store, permissionsManager, manager) networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 9d2384cae..ddeda6d7f 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -17,9 +17,9 @@ import ( "github.com/netbirdio/netbird/management/server/types" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" - reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" idpmanager "github.com/netbirdio/netbird/management/server/idp" @@ -73,7 +73,7 @@ const ( ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. -func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, reverseProxyManager reverseproxy.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -173,8 +173,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks idp.AddEndpoints(accountManager, router) instance.AddEndpoints(instanceManager, router) instance.AddVersionEndpoint(instanceManager, router) - if reverseProxyManager != nil && reverseProxyDomainManager != nil { - reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + if serviceManager != nil && reverseProxyDomainManager != nil { + reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) } // Register OAuth callback handler for proxy authentication diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 12634dda4..c7fd08da8 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -18,8 +18,8 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -190,7 +190,8 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcServer := newFakeOIDCServer() - tokenStore := nbgrpc.NewOneTimeTokenStore(time.Minute) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + require.NoError(t, err) usersManager := users.NewManager(testStore) @@ -208,9 +209,10 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcConfig, nil, usersManager, + nil, ) - proxyService.SetProxyManager(&testServiceManager{store: testStore}) + proxyService.SetServiceManager(&testServiceManager{store: testStore}) handler := NewAuthCallbackHandler(proxyService, nil) @@ -239,12 +241,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store pubKey := base64.StdEncoding.EncodeToString(pub) privKey := base64.StdEncoding.EncodeToString(priv) - testProxy := &reverseproxy.Service{ + testProxy := &service.Service{ ID: "testProxyId", AccountID: "testAccountId", Name: "Test Proxy", Domain: "test-proxy.example.com", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "localhost", Port: 8080, @@ -254,8 +256,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store Enabled: true, }}, Enabled: true, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"allowedGroupId"}, }, @@ -265,12 +267,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store } require.NoError(t, testStore.CreateService(ctx, testProxy)) - restrictedProxy := &reverseproxy.Service{ + restrictedProxy := &service.Service{ ID: "restrictedProxyId", AccountID: "testAccountId", Name: "Restricted Proxy", Domain: "restricted-proxy.example.com", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "localhost", Port: 8080, @@ -280,8 +282,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store Enabled: true, }}, Enabled: true, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: true, DistributionGroups: []string{"restrictedGroupId"}, }, @@ -291,12 +293,12 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store } require.NoError(t, testStore.CreateService(ctx, restrictedProxy)) - noAuthProxy := &reverseproxy.Service{ + noAuthProxy := &service.Service{ ID: "noAuthProxyId", AccountID: "testAccountId", Name: "No Auth Proxy", Domain: "no-auth-proxy.example.com", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "localhost", Port: 8080, @@ -306,8 +308,8 @@ func createTestReverseProxies(t *testing.T, ctx context.Context, testStore store Enabled: true, }}, Enabled: true, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{ + Auth: service.AuthConfig{ + BearerAuth: &service.BearerAuthConfig{ Enabled: false, }, }, @@ -361,19 +363,19 @@ func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, u return nil } -func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) { +func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*service.Service, error) { return nil, nil } -func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*reverseproxy.Service, error) { +func (m *testServiceManager) GetService(_ context.Context, _, _, _ string) (*service.Service, error) { return nil, nil } -func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } -func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, nil } @@ -385,7 +387,7 @@ func (m *testServiceManager) SetCertificateIssuedAt(_ context.Context, _, _ stri return nil } -func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ reverseproxy.ProxyStatus) error { +func (m *testServiceManager) SetStatus(_ context.Context, _, _ string, _ service.Status) error { return nil } @@ -397,15 +399,15 @@ func (m *testServiceManager) ReloadService(_ context.Context, _, _ string) error return nil } -func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *testServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return m.store.GetServices(ctx, store.LockingStrengthNone) } -func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*reverseproxy.Service, error) { +func (m *testServiceManager) GetServiceByID(ctx context.Context, accountID, proxyID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, proxyID) } -func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *testServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } @@ -413,7 +415,7 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri return "", nil } -func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { +func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { return nil, nil } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index fd2dc5848..1d74f88d5 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -9,10 +9,13 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/metric/noop" + "github.com/netbirdio/management-integrations/integrations" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" - reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager" + proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager" + reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" @@ -91,12 +94,24 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore := nbgrpc.NewOneTimeTokenStore(1 * time.Minute) - proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager) - domainManager := manager.NewManager(store, proxyServiceServer, permissionsManager) - reverseProxyManager := reverseproxymanager.NewManager(store, am, permissionsManager, settingsManager, proxyServiceServer, domainManager) - proxyServiceServer.SetProxyManager(reverseProxyManager) - am.SetServiceManager(reverseProxyManager) + proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + if err != nil { + t.Fatalf("Failed to create proxy token store: %v", err) + } + noopMeter := noop.NewMeterProvider().Meter("") + proxyMgr, err := proxymanager.NewManager(store, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy manager: %v", err) + } + proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr) + domainManager := manager.NewManager(store, proxyMgr, permissionsManager) + serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) + if err != nil { + t.Fatalf("Failed to create proxy controller: %v", err) + } + serviceManager := reverseproxymanager.NewManager(store, am, permissionsManager, serviceProxyController, domainManager) + proxyServiceServer.SetServiceManager(serviceManager) + am.SetServiceManager(serviceManager) // @note this is required so that PAT's validate from store, but JWT's are mocked authManager := serverauth.NewManager(store, "", "", "", "", []string{}, false) @@ -114,7 +129,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, reverseProxyManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 9b1383c6c..f25a72181 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -14,7 +14,7 @@ import ( "github.com/hashicorp/go-version" "github.com/netbirdio/netbird/idp/dex" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/types" @@ -358,12 +358,12 @@ func (w *Worker) generateProperties(ctx context.Context) properties { } servicesTargets += len(service.Targets) - switch reverseproxy.ProxyStatus(service.Meta.Status) { - case reverseproxy.StatusActive: + switch rpservice.Status(service.Meta.Status) { + case rpservice.StatusActive: servicesStatusActive++ - case reverseproxy.StatusPending: + case rpservice.StatusPending: servicesStatusPending++ - case reverseproxy.StatusError, reverseproxy.StatusCertificateFailed, reverseproxy.StatusTunnelNotCreated: + case rpservice.StatusError, rpservice.StatusCertificateFailed, rpservice.StatusTunnelNotCreated: servicesStatusError++ } diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index bc4d68178..412559bff 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -6,7 +6,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/idp/dex" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -116,29 +116,29 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, }, }, - Services: []*reverseproxy.Service{ + Services: []*rpservice.Service{ { ID: "svc1", Enabled: true, - Targets: []*reverseproxy.Target{ + Targets: []*rpservice.Target{ {TargetType: "peer"}, {TargetType: "host"}, }, - Auth: reverseproxy.AuthConfig{ - PasswordAuth: &reverseproxy.PasswordAuthConfig{Enabled: true}, + Auth: rpservice.AuthConfig{ + PasswordAuth: &rpservice.PasswordAuthConfig{Enabled: true}, }, - Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusActive)}, + Meta: rpservice.Meta{Status: string(rpservice.StatusActive)}, }, { ID: "svc2", Enabled: false, - Targets: []*reverseproxy.Target{ + Targets: []*rpservice.Target{ {TargetType: "domain"}, }, - Auth: reverseproxy.AuthConfig{ - BearerAuth: &reverseproxy.BearerAuthConfig{Enabled: true}, + Auth: rpservice.AuthConfig{ + BearerAuth: &rpservice.BearerAuthConfig{Enabled: true}, }, - Meta: reverseproxy.ServiceMeta{Status: string(reverseproxy.StatusPending)}, + Meta: rpservice.Meta{Status: string(rpservice.StatusPending)}, }, }, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ea848328f..afd2021ac 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -12,7 +12,7 @@ import ( "google.golang.org/grpc/status" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" @@ -148,7 +148,7 @@ type MockAccountManager struct { DeleteUserInviteFunc func(ctx context.Context, accountID, initiatorUserID, inviteID string) error } -func (am *MockAccountManager) SetServiceManager(serviceManager reverseproxy.Manager) { +func (am *MockAccountManager) SetServiceManager(serviceManager service.Manager) { // Mock implementation - no-op } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 843ca93e5..86f9b6579 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -7,7 +7,7 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/groups" @@ -33,23 +33,23 @@ type Manager interface { } type managerImpl struct { - store store.Store - permissionsManager permissions.Manager - groupsManager groups.Manager - accountManager account.Manager - reverseProxyManager reverseproxy.Manager + store store.Store + permissionsManager permissions.Manager + groupsManager groups.Manager + accountManager account.Manager + serviceManager service.Manager } type mockManager struct { } -func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager reverseproxy.Manager) Manager { +func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager account.Manager, reverseproxyManager service.Manager) Manager { return &managerImpl{ - store: store, - permissionsManager: permissionsManager, - groupsManager: groupsManager, - accountManager: accountManager, - reverseProxyManager: reverseproxyManager, + store: store, + permissionsManager: permissionsManager, + groupsManager: groupsManager, + accountManager: accountManager, + serviceManager: reverseproxyManager, } } @@ -264,7 +264,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc // TODO: optimize to only reload reverse proxies that are affected by the resource update instead of all of them go func() { - err := m.reverseProxyManager.ReloadAllServicesForAccount(ctx, resource.AccountID) + err := m.serviceManager.ReloadAllServicesForAccount(ctx, resource.AccountID) if err != nil { log.WithContext(ctx).Warnf("failed to reload all proxies for account: %v", err) } @@ -322,7 +322,7 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return status.NewPermissionDeniedError() } - serviceID, err := m.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, resourceID) + serviceID, err := m.serviceManager.GetServiceIDByTargetID(ctx, accountID, resourceID) if err != nil { return fmt.Errorf("failed to check if resource is used by service: %w", err) } diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go index 99de484e5..c6d8e7bcc 100644 --- a/management/server/networks/resources/manager_test.go +++ b/management/server/networks/resources/manager_test.go @@ -7,7 +7,7 @@ import ( "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -31,8 +31,8 @@ func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.NoError(t, err) @@ -54,8 +54,8 @@ func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) require.Error(t, err) @@ -76,8 +76,8 @@ func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.NoError(t, err) @@ -98,8 +98,8 @@ func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) require.Error(t, err) @@ -123,8 +123,8 @@ func Test_GetResourceInNetworkReturnsResources(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -147,8 +147,8 @@ func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) @@ -176,9 +176,9 @@ func Test_CreateResourceSuccessfully(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), resource.AccountID).Return(nil).AnyTimes() + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.NoError(t, err) @@ -205,8 +205,8 @@ func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -234,8 +234,8 @@ func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -262,8 +262,8 @@ func Test_CreateResourceFailsWithUsedName(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) createdResource, err := manager.CreateResource(ctx, userID, resource) require.Error(t, err) @@ -294,9 +294,9 @@ func Test_UpdateResourceSuccessfully(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - reverseProxyManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + serviceManager.EXPECT().ReloadAllServicesForAccount(gomock.Any(), accountID).Return(nil).AnyTimes() + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.NoError(t, err) @@ -329,8 +329,8 @@ func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -361,8 +361,8 @@ func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -392,8 +392,8 @@ func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) updatedResource, err := manager.UpdateResource(ctx, userID, resource) require.Error(t, err) @@ -416,9 +416,9 @@ func Test_DeleteResourceSuccessfully(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - reverseProxyManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + serviceManager.EXPECT().GetServiceIDByTargetID(gomock.Any(), accountID, resourceID).Return("", nil).AnyTimes() + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.NoError(t, err) @@ -440,8 +440,8 @@ func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { am := mock_server.MockAccountManager{} groupsManager := groups.NewManagerMock() ctrl := gomock.NewController(t) - reverseProxyManager := reverseproxy.NewMockManager(ctrl) - manager := NewManager(store, permissionsManager, groupsManager, &am, reverseProxyManager) + serviceManager := reverseproxy.NewMockManager(ctrl) + manager := NewManager(store, permissionsManager, groupsManager, &am, serviceManager) err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) require.Error(t, err) diff --git a/management/server/peer.go b/management/server/peer.go index a2ca97208..78ecbfcae 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -493,7 +493,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var settings *types.Settings var eventsToStore []func() - serviceID, err := am.reverseProxyManager.GetServiceIDByTargetID(ctx, accountID, peerID) + serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID) if err != nil { return fmt.Errorf("failed to check if resource is used by service: %w", err) } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 269b30822..db392ddda 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -352,9 +352,10 @@ func (p *Peer) FromAPITemporaryAccessRequest(a *api.PeerTemporaryAccessRequest) p.Name = a.Name p.Key = a.WgPubKey p.Meta = PeerSystemMeta{ - Hostname: a.Name, - GoOS: "js", - OS: "js", + Hostname: a.Name, + GoOS: "js", + OS: "js", + KernelVersion: "wasm", } } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 04045f226..41c53980b 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -28,9 +28,10 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -131,8 +132,8 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, - &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &reverseproxy.Service{}, &reverseproxy.Target{}, &domain.Domain{}, - &accesslogs.AccessLogEntry{}, + &types.Job{}, &zones.Zone{}, &records.Record{}, &types.UserInviteRecord{}, &rpservice.Service{}, &rpservice.Target{}, &domain.Domain{}, + &accesslogs.AccessLogEntry{}, &proxy.Proxy{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -2075,7 +2076,7 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p return checks, nil } -func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) { const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth, meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster, pass_host_header, rewrite_redirects, session_private_key, session_public_key @@ -2090,8 +2091,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers return nil, err } - services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*reverseproxy.Service, error) { - var s reverseproxy.Service + services, err := pgx.CollectRows(serviceRows, func(row pgx.CollectableRow) (*rpservice.Service, error) { + var s rpservice.Service var auth []byte var createdAt, certIssuedAt sql.NullTime var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString @@ -2121,7 +2122,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers } } - s.Meta = reverseproxy.ServiceMeta{} + s.Meta = rpservice.Meta{} if createdAt.Valid { s.Meta.CreatedAt = createdAt.Time } @@ -2142,7 +2143,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers s.SessionPublicKey = sessionPublicKey.String } - s.Targets = []*reverseproxy.Target{} + s.Targets = []*rpservice.Target{} return &s, nil }) if err != nil { @@ -2154,7 +2155,7 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers } serviceIDs := make([]string, len(services)) - serviceMap := make(map[string]*reverseproxy.Service) + serviceMap := make(map[string]*rpservice.Service) for i, s := range services { serviceIDs[i] = s.ID serviceMap[s.ID] = s @@ -2165,8 +2166,8 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*revers return nil, err } - targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*reverseproxy.Target, error) { - var t reverseproxy.Target + targets, err := pgx.CollectRows(targetRows, func(row pgx.CollectableRow) (*rpservice.Target, error) { + var t rpservice.Target var path sql.NullString err := row.Scan( &t.ID, @@ -4852,7 +4853,7 @@ func (s *SqlStore) GetPeerIDByKey(ctx context.Context, lockStrength LockingStren return peerID, nil } -func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { +func (s *SqlStore) CreateService(ctx context.Context, service *rpservice.Service) error { serviceCopy := service.Copy() if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt service data: %w", err) @@ -4866,16 +4867,19 @@ func (s *SqlStore) CreateService(ctx context.Context, service *reverseproxy.Serv return nil } -func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { +func (s *SqlStore) UpdateService(ctx context.Context, service *rpservice.Service) error { serviceCopy := service.Copy() if err := serviceCopy.EncryptSensitiveData(s.fieldEncrypt); err != nil { return fmt.Errorf("encrypt service data: %w", err) } + // Create target type instance outside transaction to avoid variable shadowing + targetType := &rpservice.Target{} + // Use a transaction to ensure atomic updates of the service and its targets err := s.db.Transaction(func(tx *gorm.DB) error { // Delete existing targets - if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(&reverseproxy.Target{}).Error; err != nil { + if err := tx.Where("service_id = ?", serviceCopy.ID).Delete(targetType).Error; err != nil { return err } @@ -4896,7 +4900,7 @@ func (s *SqlStore) UpdateService(ctx context.Context, service *reverseproxy.Serv } func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID string) error { - result := s.db.Delete(&reverseproxy.Service{}, accountAndIDQueryCondition, accountID, serviceID) + result := s.db.Delete(&rpservice.Service{}, accountAndIDQueryCondition, accountID, serviceID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete service from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete service from store") @@ -4910,7 +4914,7 @@ func (s *SqlStore) DeleteService(ctx context.Context, accountID, serviceID strin } func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error { - result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID) + result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ? AND id = ?", accountID, serviceID, targetID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete target from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete target from store") @@ -4924,7 +4928,7 @@ func (s *SqlStore) DeleteTarget(ctx context.Context, accountID string, serviceID } func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error { - result := s.db.Delete(&reverseproxy.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID) + result := s.db.Delete(&rpservice.Target{}, "account_id = ? AND service_id = ?", accountID, serviceID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete targets from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete targets from store") @@ -4934,8 +4938,8 @@ func (s *SqlStore) DeleteServiceTargets(ctx context.Context, accountID string, s } // GetTargetsByServiceID retrieves all targets for a given service -func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) { - var targets []*reverseproxy.Target +func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) { + var targets []*rpservice.Target tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) @@ -4949,13 +4953,13 @@ func (s *SqlStore) GetTargetsByServiceID(ctx context.Context, lockStrength Locki return targets, nil } -func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { +func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var service *reverseproxy.Service + var service *rpservice.Service result := tx.Take(&service, accountAndIDQueryCondition, accountID, serviceID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -4973,30 +4977,8 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren return service, nil } -func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { - tx := s.db.Preload("Targets") - if lockStrength != LockingStrengthNone { - tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) - } - - var serviceList []*reverseproxy.Service - result := tx.Find(&serviceList, accountIDCondition, accountID) - if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get services from store") - } - - for _, service := range serviceList { - if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil { - return nil, fmt.Errorf("decrypt service data: %w", err) - } - } - - return serviceList, nil -} - -func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { - var service *reverseproxy.Service +func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) { + var service *rpservice.Service result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -5014,13 +4996,13 @@ func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain str return service, nil } -func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { +func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var serviceList []*reverseproxy.Service + var serviceList []*rpservice.Service result := tx.Find(&serviceList) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) @@ -5036,13 +5018,13 @@ func (s *SqlStore) GetServices(ctx context.Context, lockStrength LockingStrength return serviceList, nil } -func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { +func (s *SqlStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) { tx := s.db.Preload("Targets") if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var serviceList []*reverseproxy.Service + var serviceList []*rpservice.Service result := tx.Find(&serviceList, accountIDCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error) @@ -5270,13 +5252,13 @@ func (s *SqlStore) applyAccessLogFilters(query *gorm.DB, filter accesslogs.Acces return query } -func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) { +func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) { tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } - var target *reverseproxy.Target + var target *rpservice.Target result := tx.Take(&target, "account_id = ? AND target_id = ?", accountID, targetID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -5289,3 +5271,65 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength return target, nil } + +// SaveProxy saves or updates a proxy in the database +func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error { + result := s.db.WithContext(ctx).Save(p) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error) + return status.Errorf(status.Internal, "failed to save proxy") + } + return nil +} + +// UpdateProxyHeartbeat updates the last_seen timestamp for a proxy +func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error { + result := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Where("id = ? AND status = ?", proxyID, "connected"). + Update("last_seen", time.Now()) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update proxy heartbeat: %v", result.Error) + return status.Errorf(status.Internal, "failed to update proxy heartbeat") + } + return nil +} + +// GetActiveProxyClusterAddresses returns all unique cluster addresses for active proxies +func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + var addresses []string + + result := s.db.WithContext(ctx). + Model(&proxy.Proxy{}). + Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-2*time.Minute)). + Distinct("cluster_address"). + Pluck("cluster_address", &addresses) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get active proxy cluster addresses") + } + + return addresses, nil +} + +// CleanupStaleProxies deletes proxies that haven't sent heartbeat in the specified duration +func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { + cutoffTime := time.Now().Add(-inactivityDuration) + + result := s.db.WithContext(ctx). + Where("last_seen < ?", cutoffTime). + Delete(&proxy.Proxy{}) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", result.Error) + return status.Errorf(status.Internal, "failed to cleanup stale proxies") + } + + if result.RowsAffected > 0 { + log.WithContext(ctx).Infof("Cleaned up %d stale proxies", result.RowsAffected) + } + + return nil +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index fa9a9dbf5..f2abafceb 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -20,7 +20,7 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -264,7 +264,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, - &types.AccountOnboarding{}, &reverseproxy.Service{}, &reverseproxy.Target{}, + &types.AccountOnboarding{}, &service.Service{}, &service.Target{}, } for i := len(models) - 1; i >= 0; i-- { diff --git a/management/server/store/store.go b/management/server/store/store.go index 9e982f70b..941aca08a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -25,9 +25,10 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/telemetry" @@ -252,14 +253,13 @@ type Store interface { MarkAllPendingJobsAsFailed(ctx context.Context, accountID, peerID, reason string) error GetPeerIDByKey(ctx context.Context, lockStrength LockingStrength, key string) (string, error) - CreateService(ctx context.Context, service *reverseproxy.Service) error - UpdateService(ctx context.Context, service *reverseproxy.Service) error + CreateService(ctx context.Context, service *rpservice.Service) error + UpdateService(ctx context.Context, service *rpservice.Service) error DeleteService(ctx context.Context, accountID, serviceID string) error - GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) - GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) - GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) - GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) - GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) + GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*rpservice.Service, error) + GetServiceByDomain(ctx context.Context, accountID, domain string) (*rpservice.Service, error) + GetServices(ctx context.Context, lockStrength LockingStrength) ([]*rpservice.Service, error) + GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*rpservice.Service, error) GetCustomDomain(ctx context.Context, accountID string, domainID string) (*domain.Domain, error) ListFreeDomains(ctx context.Context, accountID string) ([]string, error) @@ -271,12 +271,16 @@ type Store interface { CreateAccessLog(ctx context.Context, log *accesslogs.AccessLogEntry) error GetAccountAccessLogs(ctx context.Context, lockStrength LockingStrength, accountID string, filter accesslogs.AccessLogFilter) ([]*accesslogs.AccessLogEntry, int64, error) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) - GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*reverseproxy.Target, error) - GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*reverseproxy.Target, error) + GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID string, targetID string) (*rpservice.Target, error) + GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID string, serviceID string) ([]*rpservice.Target, error) DeleteTarget(ctx context.Context, accountID string, serviceID string, targetID uint) error DeleteServiceTargets(ctx context.Context, accountID string, serviceID string) error - // GetCustomDomainsCounts returns the total and validated custom domain counts. + SaveProxy(ctx context.Context, proxy *proxy.Proxy) error + UpdateProxyHeartbeat(ctx context.Context, proxyID string) error + GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) + CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error + GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error) } diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 682ecc4d8..9e11f85fb 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -12,9 +12,10 @@ import ( gomock "github.com/golang/mock/gomock" dns "github.com/netbirdio/netbird/dns" - reverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" accesslogs "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" domain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" + proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" zones "github.com/netbirdio/netbird/management/internals/modules/zones" records "github.com/netbirdio/netbird/management/internals/modules/zones/records" types "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -150,6 +151,20 @@ func (mr *MockStoreMockRecorder) ApproveAccountPeers(ctx, accountID interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ApproveAccountPeers", reflect.TypeOf((*MockStore)(nil).ApproveAccountPeers), ctx, accountID) } +// CleanupStaleProxies mocks base method. +func (m *MockStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CleanupStaleProxies", ctx, inactivityDuration) + ret0, _ := ret[0].(error) + return ret0 +} + +// CleanupStaleProxies indicates an expected call of CleanupStaleProxies. +func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration) +} + // Close mocks base method. func (m *MockStore) Close(ctx context.Context) error { m.ctrl.T.Helper() @@ -293,7 +308,7 @@ func (mr *MockStoreMockRecorder) CreatePolicy(ctx, policy interface{}) *gomock.C } // CreateService mocks base method. -func (m *MockStore) CreateService(ctx context.Context, service *reverseproxy.Service) error { +func (m *MockStore) CreateService(ctx context.Context, service *service.Service) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "CreateService", ctx, service) ret0, _ := ret[0].(error) @@ -1123,10 +1138,10 @@ func (mr *MockStoreMockRecorder) GetAccountRoutes(ctx, lockStrength, accountID i } // GetAccountServices mocks base method. -func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { +func (m *MockStore) GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetAccountServices", ctx, lockStrength, accountID) - ret0, _ := ret[0].([]*reverseproxy.Service) + ret0, _ := ret[0].([]*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1227,6 +1242,21 @@ func (mr *MockStoreMockRecorder) GetAccountsCounter(ctx interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountsCounter", reflect.TypeOf((*MockStore)(nil).GetAccountsCounter), ctx) } +// GetActiveProxyClusterAddresses mocks base method. +func (m *MockStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetActiveProxyClusterAddresses", ctx) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetActiveProxyClusterAddresses indicates an expected call of GetActiveProxyClusterAddresses. +func (mr *MockStoreMockRecorder) GetActiveProxyClusterAddresses(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveProxyClusterAddresses", reflect.TypeOf((*MockStore)(nil).GetActiveProxyClusterAddresses), ctx) +} + // GetAllAccounts mocks base method. func (m *MockStore) GetAllAccounts(ctx context.Context) []*types2.Account { m.ctrl.T.Helper() @@ -1857,10 +1887,10 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout } // GetServiceByDomain mocks base method. -func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) { +func (m *MockStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServiceByDomain", ctx, accountID, domain) - ret0, _ := ret[0].(*reverseproxy.Service) + ret0, _ := ret[0].(*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1872,10 +1902,10 @@ func (mr *MockStoreMockRecorder) GetServiceByDomain(ctx, accountID, domain inter } // GetServiceByID mocks base method. -func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error) { +func (m *MockStore) GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServiceByID", ctx, lockStrength, accountID, serviceID) - ret0, _ := ret[0].(*reverseproxy.Service) + ret0, _ := ret[0].(*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1887,10 +1917,10 @@ func (mr *MockStoreMockRecorder) GetServiceByID(ctx, lockStrength, accountID, se } // GetServiceTargetByTargetID mocks base method. -func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*reverseproxy.Target, error) { +func (m *MockStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength LockingStrength, accountID, targetID string) (*service.Target, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServiceTargetByTargetID", ctx, lockStrength, accountID, targetID) - ret0, _ := ret[0].(*reverseproxy.Target) + ret0, _ := ret[0].(*service.Target) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1902,10 +1932,10 @@ func (mr *MockStoreMockRecorder) GetServiceTargetByTargetID(ctx, lockStrength, a } // GetServices mocks base method. -func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error) { +func (m *MockStore) GetServices(ctx context.Context, lockStrength LockingStrength) ([]*service.Service, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetServices", ctx, lockStrength) - ret0, _ := ret[0].([]*reverseproxy.Service) + ret0, _ := ret[0].([]*service.Service) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -1916,21 +1946,6 @@ func (mr *MockStoreMockRecorder) GetServices(ctx, lockStrength interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServices", reflect.TypeOf((*MockStore)(nil).GetServices), ctx, lockStrength) } -// GetServicesByAccountID mocks base method. -func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID) - ret0, _ := ret[0].([]*reverseproxy.Service) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetServicesByAccountID indicates an expected call of GetServicesByAccountID. -func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID) -} - // GetSetupKeyByID mocks base method. func (m *MockStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types2.SetupKey, error) { m.ctrl.T.Helper() @@ -1991,10 +2006,10 @@ func (mr *MockStoreMockRecorder) GetTakenIPs(ctx, lockStrength, accountId interf } // GetTargetsByServiceID mocks base method. -func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*reverseproxy.Target, error) { +func (m *MockStore) GetTargetsByServiceID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) ([]*service.Target, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetTargetsByServiceID", ctx, lockStrength, accountID, serviceID) - ret0, _ := ret[0].([]*reverseproxy.Target) + ret0, _ := ret[0].([]*service.Target) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -2610,6 +2625,20 @@ func (mr *MockStoreMockRecorder) SavePostureChecks(ctx, postureCheck interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePostureChecks", reflect.TypeOf((*MockStore)(nil).SavePostureChecks), ctx, postureCheck) } +// SaveProxy mocks base method. +func (m *MockStore) SaveProxy(ctx context.Context, proxy *proxy.Proxy) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SaveProxy", ctx, proxy) + ret0, _ := ret[0].(error) + return ret0 +} + +// SaveProxy indicates an expected call of SaveProxy. +func (mr *MockStoreMockRecorder) SaveProxy(ctx, proxy interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SaveProxy", reflect.TypeOf((*MockStore)(nil).SaveProxy), ctx, proxy) +} + // SaveProxyAccessToken mocks base method. func (m *MockStore) SaveProxyAccessToken(ctx context.Context, token *types2.ProxyAccessToken) error { m.ctrl.T.Helper() @@ -2805,8 +2834,22 @@ func (mr *MockStoreMockRecorder) UpdateGroups(ctx, accountID, groups interface{} return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroups", reflect.TypeOf((*MockStore)(nil).UpdateGroups), ctx, accountID, groups) } +// UpdateProxyHeartbeat mocks base method. +func (m *MockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateProxyHeartbeat", ctx, proxyID) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateProxyHeartbeat indicates an expected call of UpdateProxyHeartbeat. +func (mr *MockStoreMockRecorder) UpdateProxyHeartbeat(ctx, proxyID interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateProxyHeartbeat", reflect.TypeOf((*MockStore)(nil).UpdateProxyHeartbeat), ctx, proxyID) +} + // UpdateService mocks base method. -func (m *MockStore) UpdateService(ctx context.Context, service *reverseproxy.Service) error { +func (m *MockStore) UpdateService(ctx context.Context, service *service.Service) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "UpdateService", ctx, service) ret0, _ := ret[0].(error) diff --git a/management/server/types/account.go b/management/server/types/account.go index 3208cc89a..6145ceeb2 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -18,7 +18,7 @@ import ( "github.com/netbirdio/netbird/client/ssh/auth" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -100,7 +100,7 @@ type Account struct { NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` - Services []*reverseproxy.Service `gorm:"foreignKey:AccountID;references:id"` + Services []*service.Service `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` @@ -906,7 +906,7 @@ func (a *Account) Copy() *Account { networkResources = append(networkResources, resource.Copy()) } - services := []*reverseproxy.Service{} + services := []*service.Service{} for _, service := range a.Services { services = append(services, service.Copy()) } @@ -1814,7 +1814,7 @@ func (a *Account) InjectProxyPolicies(ctx context.Context) { } } -func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *reverseproxy.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { +func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *service.Service, proxyPeersByCluster map[string][]*nbpeer.Peer) { for _, target := range service.Targets { if !target.Enabled { continue @@ -1823,7 +1823,7 @@ func (a *Account) injectServiceProxyPolicies(ctx context.Context, service *rever } } -func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *reverseproxy.Service, target *reverseproxy.Target, proxyPeers []*nbpeer.Peer) { +func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *service.Service, target *service.Target, proxyPeers []*nbpeer.Peer) { port, ok := a.resolveTargetPort(ctx, target) if !ok { return @@ -1840,7 +1840,7 @@ func (a *Account) injectTargetProxyPolicies(ctx context.Context, service *revers } } -func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Target) (int, bool) { +func (a *Account) resolveTargetPort(ctx context.Context, target *service.Target) (int, bool) { if target.Port != 0 { return target.Port, true } @@ -1856,7 +1856,7 @@ func (a *Account) resolveTargetPort(ctx context.Context, target *reverseproxy.Ta } } -func (a *Account) createProxyPolicy(service *reverseproxy.Service, target *reverseproxy.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { +func (a *Account) createProxyPolicy(service *service.Service, target *service.Target, proxyPeer *nbpeer.Peer, port int, path string) *Policy { policyID := fmt.Sprintf("proxy-access-%s-%s-%s", service.ID, proxyPeer.ID, path) return &Policy{ ID: policyID, diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index c594f9800..50aa38b29 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -42,6 +42,8 @@ var ( acmeCerts bool acmeAddr string acmeDir string + acmeEABKID string + acmeEABHMACKey string acmeChallengeType string debugEndpoint bool debugEndpointAddr string @@ -74,6 +76,8 @@ func init() { rootCmd.Flags().BoolVar(&acmeCerts, "acme-certs", envBoolOrDefault("NB_PROXY_ACME_CERTIFICATES", false), "Generate ACME certificates automatically") rootCmd.Flags().StringVar(&acmeAddr, "acme-addr", envStringOrDefault("NB_PROXY_ACME_ADDRESS", ":80"), "HTTP address for ACME HTTP-01 challenges (only used when acme-challenge-type is http-01)") rootCmd.Flags().StringVar(&acmeDir, "acme-dir", envStringOrDefault("NB_PROXY_ACME_DIRECTORY", acme.LetsEncryptURL), "URL of ACME challenge directory") + rootCmd.Flags().StringVar(&acmeEABKID, "acme-eab-kid", envStringOrDefault("NB_PROXY_ACME_EAB_KID", ""), "ACME EAB KID for account registration") + rootCmd.Flags().StringVar(&acmeEABHMACKey, "acme-eab-hmac-key", envStringOrDefault("NB_PROXY_ACME_EAB_HMAC_KEY", ""), "ACME EAB HMAC key for account registration") rootCmd.Flags().StringVar(&acmeChallengeType, "acme-challenge-type", envStringOrDefault("NB_PROXY_ACME_CHALLENGE_TYPE", "tls-alpn-01"), "ACME challenge type: tls-alpn-01 (default, port 443 only) or http-01 (requires port 80)") rootCmd.Flags().BoolVar(&debugEndpoint, "debug-endpoint", envBoolOrDefault("NB_PROXY_DEBUG_ENDPOINT", false), "Enable debug HTTP endpoint") rootCmd.Flags().StringVar(&debugEndpointAddr, "debug-endpoint-addr", envStringOrDefault("NB_PROXY_DEBUG_ENDPOINT_ADDRESS", "localhost:8444"), "Address for the debug HTTP endpoint") @@ -149,6 +153,8 @@ func runServer(cmd *cobra.Command, args []string) error { GenerateACMECertificates: acmeCerts, ACMEChallengeAddress: acmeAddr, ACMEDirectory: acmeDir, + ACMEEABKID: acmeEABKID, + ACMEEABHMACKey: acmeEABHMACKey, ACMEChallengeType: acmeChallengeType, DebugEndpointEnabled: debugEndpoint, DebugEndpointAddress: debugEndpointAddr, diff --git a/proxy/internal/acme/manager.go b/proxy/internal/acme/manager.go index a663b8138..d491d65a3 100644 --- a/proxy/internal/acme/manager.go +++ b/proxy/internal/acme/manager.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/asn1" + "encoding/base64" "encoding/binary" "fmt" "net" @@ -59,7 +60,10 @@ type Manager struct { // NewManager creates a new ACME certificate manager. The certDir is used // for caching certificates. The lockMethod controls cross-replica // coordination strategy (see CertLockMethod constants). -func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { +// eabKID and eabHMACKey are optional External Account Binding credentials +// required for some CAs like ZeroSSL. The eabHMACKey should be the base64 +// URL-encoded string provided by the CA. +func NewManager(certDir, acmeURL, eabKID, eabHMACKey string, notifier certificateNotifier, logger *log.Logger, lockMethod CertLockMethod) *Manager { if logger == nil { logger = log.StandardLogger() } @@ -70,10 +74,26 @@ func NewManager(certDir, acmeURL string, notifier certificateNotifier, logger *l certNotifier: notifier, logger: logger, } + + var eab *acme.ExternalAccountBinding + if eabKID != "" && eabHMACKey != "" { + decodedKey, err := base64.RawURLEncoding.DecodeString(eabHMACKey) + if err != nil { + logger.Errorf("failed to decode EAB HMAC key: %v", err) + } else { + eab = &acme.ExternalAccountBinding{ + KID: eabKID, + Key: decodedKey, + } + logger.Infof("configured External Account Binding with KID: %s", eabKID) + } + } + mgr.Manager = &autocert.Manager{ - Prompt: autocert.AcceptTOS, - HostPolicy: mgr.hostPolicy, - Cache: autocert.DirCache(certDir), + Prompt: autocert.AcceptTOS, + HostPolicy: mgr.hostPolicy, + Cache: autocert.DirCache(certDir), + ExternalAccountBinding: eab, Client: &acme.Client{ DirectoryURL: acmeURL, }, @@ -136,7 +156,7 @@ func (mgr *Manager) prefetchCertificate(d domain.Domain) { cert, err := mgr.GetCertificate(hello) elapsed := time.Since(start) if err != nil { - mgr.logger.Warnf("prefetch certificate for domain %q: %v", name, err) + mgr.logger.Warnf("prefetch certificate for domain %q in %s: %v", name, elapsed.String(), err) mgr.setDomainState(d, domainFailed, err.Error()) return } diff --git a/proxy/internal/acme/manager_test.go b/proxy/internal/acme/manager_test.go index 3b554e360..f7efe5933 100644 --- a/proxy/internal/acme/manager_test.go +++ b/proxy/internal/acme/manager_test.go @@ -10,7 +10,7 @@ import ( ) func TestHostPolicy(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") + mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "") mgr.AddDomain("example.com", "acc1", "rp1") // Wait for the background prefetch goroutine to finish so the temp dir @@ -70,7 +70,7 @@ func TestHostPolicy(t *testing.T) { } func TestDomainStates(t *testing.T) { - mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", nil, nil, "") + mgr := NewManager(t.TempDir(), "https://acme.example.com/directory", "", "", nil, nil, "") assert.Equal(t, 0, mgr.PendingCerts(), "initially zero") assert.Equal(t, 0, mgr.TotalDomains(), "initially zero domains") diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index e91335a81..3e5a21400 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -18,8 +18,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" + nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -37,7 +38,7 @@ type integrationTestSetup struct { grpcServer *grpc.Server grpcAddr string cleanup func() - services []*reverseproxy.Service + services []*service.Service } func setupIntegrationTest(t *testing.T) *integrationTestSetup { @@ -66,13 +67,13 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { privKey := base64.StdEncoding.EncodeToString(priv) // Create test services in the store - services := []*reverseproxy.Service{ + services := []*service.Service{ { ID: "rp-1", AccountID: "test-account-1", Name: "Test App 1", Domain: "app1.test.proxy.io", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "10.0.0.1", Port: 8080, @@ -91,7 +92,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { AccountID: "test-account-1", Name: "Test App 2", Domain: "app2.test.proxy.io", - Targets: []*reverseproxy.Target{{ + Targets: []*service.Target{{ Path: strPtr("/"), Host: "10.0.0.2", Port: 8080, @@ -112,7 +113,8 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { } // Create real token store - tokenStore := nbgrpc.NewOneTimeTokenStore(5 * time.Minute) + tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) // Create real users manager usersManager := users.NewManager(testStore) @@ -124,17 +126,23 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { HMACKey: []byte("test-hmac-key"), } + proxyManager := &testProxyManager{} + proxyService := nbgrpc.NewProxyServiceServer( &testAccessLogManager{}, tokenStore, oidcConfig, nil, usersManager, + proxyManager, ) // Use store-backed service manager svcMgr := &storeBackedServiceManager{store: testStore, tokenStore: tokenStore} - proxyService.SetProxyManager(svcMgr) + proxyService.SetServiceManager(svcMgr) + + proxyController := &testProxyController{} + proxyService.SetProxyController(proxyController) // Start real gRPC server lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -185,6 +193,52 @@ func (m *testAccessLogManager) GetAllAccessLogs(_ context.Context, _, _ string, return nil, 0, nil } +// testProxyManager is a mock implementation of proxy.Manager for testing. +type testProxyManager struct{} + +func (m *testProxyManager) Connect(_ context.Context, _, _, _ string) error { + return nil +} + +func (m *testProxyManager) Disconnect(_ context.Context, _ string) error { + return nil +} + +func (m *testProxyManager) Heartbeat(_ context.Context, _ string) error { + return nil +} + +func (m *testProxyManager) GetActiveClusterAddresses(_ context.Context) ([]string, error) { + return nil, nil +} + +func (m *testProxyManager) CleanupStale(_ context.Context, _ time.Duration) error { + return nil +} + +// testProxyController is a mock implementation of rpservice.ProxyController for testing. +type testProxyController struct{} + +func (c *testProxyController) SendServiceUpdateToCluster(_ context.Context, _ string, _ *proto.ProxyMapping, _ string) { + // noop +} + +func (c *testProxyController) GetOIDCValidationConfig() nbproxy.OIDCValidationConfig { + return nbproxy.OIDCValidationConfig{} +} + +func (c *testProxyController) RegisterProxyToCluster(_ context.Context, _, _ string) error { + return nil +} + +func (c *testProxyController) UnregisterProxyFromCluster(_ context.Context, _, _ string) error { + return nil +} + +func (c *testProxyController) GetProxiesForCluster(_ string) []string { + return nil +} + // storeBackedServiceManager reads directly from the real store. type storeBackedServiceManager struct { store store.Store @@ -195,19 +249,19 @@ func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accou return nil } -func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } -func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) } -func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) CreateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, errors.New("not implemented") } -func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) UpdateService(_ context.Context, _, _ string, _ *service.Service) (*service.Service, error) { return nil, errors.New("not implemented") } @@ -219,7 +273,7 @@ func (m *storeBackedServiceManager) SetCertificateIssuedAt(ctx context.Context, return nil } -func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { +func (m *storeBackedServiceManager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error { return nil } @@ -231,15 +285,15 @@ func (m *storeBackedServiceManager) ReloadService(ctx context.Context, accountID return nil } -func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, "test-account-1") } -func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) { return m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) } -func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *storeBackedServiceManager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) } @@ -247,8 +301,8 @@ func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, return "", nil } -func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { - return &reverseproxy.ExposeServiceResponse{}, nil +func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) { + return &service.ExposeServiceResponse{}, nil } func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { diff --git a/proxy/server.go b/proxy/server.go index 48a876899..155610305 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -84,6 +84,10 @@ type Server struct { GenerateACMECertificates bool ACMEChallengeAddress string ACMEDirectory string + // ACMEEABKID is the External Account Binding Key ID for CAs that require EAB (e.g., ZeroSSL). + ACMEEABKID string + // ACMEEABHMACKey is the External Account Binding HMAC key (base64 URL-encoded) for CAs that require EAB. + ACMEEABHMACKey string // ACMEChallengeType specifies the ACME challenge type: "http-01" or "tls-alpn-01". // Defaults to "tls-alpn-01" if not specified. ACMEChallengeType string @@ -419,7 +423,7 @@ func (s *Server) configureTLS(ctx context.Context) (*tls.Config, error) { "acme_server": s.ACMEDirectory, "challenge_type": s.ACMEChallengeType, }).Debug("ACME certificates enabled, configuring certificate manager") - s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s, s.Logger, s.CertLockMethod) + s.acme = acme.NewManager(s.CertificateDirectory, s.ACMEDirectory, s.ACMEEABKID, s.ACMEEABHMACKey, s, s.Logger, s.CertLockMethod) if s.ACMEChallengeType == "http-01" { s.http = &http.Server{