mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 22:26:42 +00:00
extract proxy controller
This commit is contained in:
@@ -4,6 +4,8 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
@@ -21,3 +23,12 @@ type Manager interface {
|
|||||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProxyController is responsible for managing proxy clusters and routing service updates.
|
||||||
|
type ProxyController interface {
|
||||||
|
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
|
||||||
|
GetOIDCValidationConfig() OIDCValidationConfig
|
||||||
|
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||||
|
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||||
|
GetProxiesForCluster(clusterAddr string) []string
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
|
proto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockManager is a mock of Manager interface.
|
// MockManager is a mock of Manager interface.
|
||||||
@@ -223,3 +224,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MockProxyController is a mock of ProxyController interface.
|
||||||
|
type MockProxyController struct {
|
||||||
|
ctrl *gomock.Controller
|
||||||
|
recorder *MockProxyControllerMockRecorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
|
||||||
|
type MockProxyControllerMockRecorder struct {
|
||||||
|
mock *MockProxyController
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockProxyController creates a new mock instance.
|
||||||
|
func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
|
||||||
|
mock := &MockProxyController{ctrl: ctrl}
|
||||||
|
mock.recorder = &MockProxyControllerMockRecorder{mock}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||||
|
func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
|
||||||
|
return m.recorder
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCValidationConfig mocks base method.
|
||||||
|
func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
|
||||||
|
ret0, _ := ret[0].(OIDCValidationConfig)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
|
||||||
|
func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxiesForCluster mocks base method.
|
||||||
|
func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
|
||||||
|
ret0, _ := ret[0].([]string)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
|
||||||
|
func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProxyToCluster mocks base method.
|
||||||
|
func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
|
||||||
|
func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendServiceUpdateToCluster mocks base method.
|
||||||
|
func (m *MockProxyController) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
m.ctrl.Call(m, "SendServiceUpdateToCluster", accountID, update, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
|
||||||
|
func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(accountID, update, clusterAddr interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), accountID, update, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProxyFromCluster mocks base method.
|
||||||
|
func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
|
||||||
|
func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
|
||||||
|
}
|
||||||
|
|||||||
@@ -0,0 +1,74 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GRPCProxyController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
|
||||||
|
type GRPCProxyController struct {
|
||||||
|
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||||
|
// Map of cluster address -> set of proxy IDs
|
||||||
|
clusterProxies sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGRPCProxyController creates a new GRPCProxyController.
|
||||||
|
func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
|
||||||
|
return &GRPCProxyController{
|
||||||
|
proxyGRPCServer: proxyGRPCServer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
|
||||||
|
func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||||
|
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
|
||||||
|
func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
|
||||||
|
return c.proxyGRPCServer.GetOIDCValidationConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
|
||||||
|
func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
if clusterAddr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||||
|
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||||
|
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnregisterProxyFromCluster removes a proxy from a cluster.
|
||||||
|
func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||||
|
if clusterAddr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
|
||||||
|
proxySet.(*sync.Map).Delete(proxyID)
|
||||||
|
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||||
|
func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string {
|
||||||
|
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var proxies []string
|
||||||
|
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||||
|
proxies = append(proxies, key.(string))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return proxies
|
||||||
|
}
|
||||||
@@ -7,16 +7,14 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions"
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/shared/management/proto"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,22 +29,22 @@ type Manager struct {
|
|||||||
store store.Store
|
store store.Store
|
||||||
accountManager account.Manager
|
accountManager account.Manager
|
||||||
permissionsManager permissions.Manager
|
permissionsManager permissions.Manager
|
||||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
proxyController service.ProxyController
|
||||||
clusterDeriver ClusterDeriver
|
clusterDeriver ClusterDeriver
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new service manager.
|
// NewManager creates a new service manager.
|
||||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) rpservice.Manager {
|
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
store: store,
|
store: store,
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
permissionsManager: permissionsManager,
|
permissionsManager: permissionsManager,
|
||||||
proxyGRPCServer: proxyGRPCServer,
|
proxyController: proxyController,
|
||||||
clusterDeriver: clusterDeriver,
|
clusterDeriver: clusterDeriver,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*rpservice.Service, error) {
|
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
@@ -70,34 +68,34 @@ func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string)
|
|||||||
return services, nil
|
return services, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, service *rpservice.Service) error {
|
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
|
||||||
for _, target := range service.Targets {
|
for _, target := range s.Targets {
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case rpservice.TargetTypePeer:
|
case service.TargetTypePeer:
|
||||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||||
target.Host = unknownHostPlaceholder
|
target.Host = unknownHostPlaceholder
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
target.Host = peer.IP.String()
|
target.Host = peer.IP.String()
|
||||||
case rpservice.TargetTypeHost:
|
case service.TargetTypeHost:
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||||
target.Host = unknownHostPlaceholder
|
target.Host = unknownHostPlaceholder
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
target.Host = resource.Prefix.Addr().String()
|
target.Host = resource.Prefix.Addr().String()
|
||||||
case rpservice.TargetTypeDomain:
|
case service.TargetTypeDomain:
|
||||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||||
target.Host = unknownHostPlaceholder
|
target.Host = unknownHostPlaceholder
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
target.Host = resource.Domain
|
target.Host = resource.Domain
|
||||||
case rpservice.TargetTypeSubnet:
|
case service.TargetTypeSubnet:
|
||||||
// For subnets we do not do any lookups on the resource
|
// For subnets we do not do any lookups on the resource
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||||
@@ -106,7 +104,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, ser
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*rpservice.Service, error) {
|
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
@@ -127,7 +125,7 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s
|
|||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, service *rpservice.Service) (*rpservice.Service, error) {
|
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
@@ -136,29 +134,29 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
if err := m.persistNewService(ctx, accountID, s); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return service, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *rpservice.Service) error {
|
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
|
||||||
if m.clusterDeriver != nil {
|
if m.clusterDeriver != nil {
|
||||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -185,7 +183,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *rpservice.Service) error {
|
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -219,7 +217,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *rpservice.Service) (*rpservice.Service, error) {
|
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.NewPermissionValidationError(err)
|
return nil, status.NewPermissionValidationError(err)
|
||||||
@@ -243,7 +241,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s
|
|||||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.sendServiceUpdateNotifications(accountID, service, updateInfo)
|
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
@@ -255,7 +253,7 @@ type serviceUpdateInfo struct {
|
|||||||
serviceEnabledChanged bool
|
serviceEnabledChanged bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *rpservice.Service) (*serviceUpdateInfo, error) {
|
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||||
var updateInfo serviceUpdateInfo
|
var updateInfo serviceUpdateInfo
|
||||||
|
|
||||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
@@ -293,7 +291,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
|||||||
return &updateInfo, err
|
return &updateInfo, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *rpservice.Service) error {
|
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -310,7 +308,7 @@ func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Stor
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *rpservice.Service) {
|
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||||
service.Auth.PasswordAuth.Password == "" {
|
service.Auth.PasswordAuth.Password == "" {
|
||||||
@@ -324,44 +322,40 @@ func (m *Manager) preserveExistingAuthSecrets(service, existingService *rpservic
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
|
||||||
m.proxyGRPCServer.SendServiceUpdateToCluster(update, clusterAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) preserveServiceMetadata(service, existingService *rpservice.Service) {
|
|
||||||
service.Meta = existingService.Meta
|
service.Meta = existingService.Meta
|
||||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||||
service.SessionPublicKey = existingService.SessionPublicKey
|
service.SessionPublicKey = existingService.SessionPublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) sendServiceUpdateNotifications(accountID string, service *rpservice.Service, updateInfo *serviceUpdateInfo) {
|
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", oidcCfg), updateInfo.oldCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", oidcCfg), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
case !s.Enabled && updateInfo.serviceEnabledChanged:
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", oidcCfg), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
|
||||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
case s.Enabled && updateInfo.serviceEnabledChanged:
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", oidcCfg), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||||
default:
|
default:
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", oidcCfg), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*rpservice.Target) error {
|
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
switch target.TargetType {
|
switch target.TargetType {
|
||||||
case rpservice.TargetTypePeer:
|
case service.TargetTypePeer:
|
||||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||||
}
|
}
|
||||||
case rpservice.TargetTypeHost, rpservice.TargetTypeSubnet, rpservice.TargetTypeDomain:
|
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
|
||||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||||
@@ -382,10 +376,10 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
|||||||
return status.NewPermissionDeniedError()
|
return status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
var service *rpservice.Service
|
var s *service.Service
|
||||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
var err error
|
var err error
|
||||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -400,9 +394,9 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
|
||||||
|
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
@@ -429,7 +423,7 @@ func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, service
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||||
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status rpservice.Status) error {
|
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -447,17 +441,17 @@ func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get service: %w", err)
|
return fmt.Errorf("failed to get service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
@@ -470,18 +464,18 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str
|
|||||||
return fmt.Errorf("failed to get services: %w", err)
|
return fmt.Errorf("failed to get services: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, service := range services {
|
for _, s := range services {
|
||||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||||
}
|
}
|
||||||
m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*rpservice.Service, error) {
|
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
@@ -497,7 +491,7 @@ func (m *Manager) GetGlobalServices(ctx context.Context) ([]*rpservice.Service,
|
|||||||
return services, nil
|
return services, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*rpservice.Service, error) {
|
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||||
@@ -511,7 +505,7 @@ func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID strin
|
|||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||||
|
|||||||
@@ -166,6 +166,7 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
|||||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager())
|
||||||
s.AfterInit(func(s *BaseServer) {
|
s.AfterInit(func(s *BaseServer) {
|
||||||
proxyService.SetProxyManager(s.ServiceManager())
|
proxyService.SetProxyManager(s.ServiceManager())
|
||||||
|
proxyService.SetProxyController(s.ServiceProxyController())
|
||||||
})
|
})
|
||||||
return proxyService
|
return proxyService
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
|
nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||||
@@ -106,6 +108,12 @@ func (s *BaseServer) NetworkMapController() network_map.Controller {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) ServiceProxyController() service.ProxyController {
|
||||||
|
return Create(s, func() service.ProxyController {
|
||||||
|
return nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||||
return Create(s, func() *server.AccountRequestBuffer {
|
return Create(s, func() *server.AccountRequestBuffer {
|
||||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ func (s *BaseServer) RecordsManager() records.Manager {
|
|||||||
|
|
||||||
func (s *BaseServer) ServiceManager() service.Manager {
|
func (s *BaseServer) ServiceManager() service.Manager {
|
||||||
return Create(s, func() service.Manager {
|
return Create(s, func() service.Manager {
|
||||||
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager())
|
return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,9 +7,11 @@ import (
|
|||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/eko/gocache/lib/v4/cache"
|
||||||
"github.com/eko/gocache/lib/v4/store"
|
"github.com/eko/gocache/lib/v4/store"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -26,7 +28,7 @@ type tokenMetadata struct {
|
|||||||
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
// OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC.
|
||||||
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
// Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var.
|
||||||
type OneTimeTokenStore struct {
|
type OneTimeTokenStore struct {
|
||||||
store store.StoreInterface
|
cache *cache.Cache[string]
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -38,7 +40,7 @@ func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.
|
|||||||
}
|
}
|
||||||
|
|
||||||
return &OneTimeTokenStore{
|
return &OneTimeTokenStore{
|
||||||
store: cacheStore,
|
cache: cache.New[string](cacheStore),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -64,7 +66,12 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
|||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.store.Set(s.ctx, hashedToken, metadata, store.WithExpiration(ttl)); err != nil {
|
metadataJSON, err := json.Marshal(metadata)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to serialize token metadata: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil {
|
||||||
return "", fmt.Errorf("failed to store token: %w", err)
|
return "", fmt.Errorf("failed to store token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,20 +94,19 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time.
|
|||||||
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error {
|
||||||
hashedToken := hashToken(token)
|
hashedToken := hashToken(token)
|
||||||
|
|
||||||
value, err := s.store.Get(s.ctx, hashedToken)
|
metadataJSON, err := s.cache.Get(s.ctx, hashedToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID)
|
||||||
return fmt.Errorf("invalid token")
|
return fmt.Errorf("invalid token")
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata, ok := value.(*tokenMetadata)
|
metadata := &tokenMetadata{}
|
||||||
if !ok {
|
if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil {
|
||||||
log.Warnf("Token validation failed: invalid metadata type (proxy: %s, account: %s)", serviceID, accountID)
|
log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||||
return fmt.Errorf("invalid token metadata")
|
return fmt.Errorf("invalid token metadata")
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Now().After(metadata.ExpiresAt) {
|
if time.Now().After(metadata.ExpiresAt) {
|
||||||
s.store.Delete(s.ctx, hashedToken)
|
|
||||||
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID)
|
||||||
return fmt.Errorf("token expired")
|
return fmt.Errorf("token expired")
|
||||||
}
|
}
|
||||||
@@ -115,7 +121,7 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID strin
|
|||||||
return fmt.Errorf("service ID mismatch")
|
return fmt.Errorf("service ID mismatch")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.store.Delete(s.ctx, hashedToken); err != nil {
|
if err := s.cache.Delete(s.ctx, hashedToken); err != nil {
|
||||||
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,9 +59,6 @@ type ProxyServiceServer struct {
|
|||||||
// Map of connected proxies: proxy_id -> proxy connection
|
// Map of connected proxies: proxy_id -> proxy connection
|
||||||
connectedProxies sync.Map
|
connectedProxies sync.Map
|
||||||
|
|
||||||
// Map of cluster address -> set of proxy IDs
|
|
||||||
clusterProxies sync.Map
|
|
||||||
|
|
||||||
// Channel for broadcasting reverse proxy updates to all proxies
|
// Channel for broadcasting reverse proxy updates to all proxies
|
||||||
updatesChan chan *proto.ProxyMapping
|
updatesChan chan *proto.ProxyMapping
|
||||||
|
|
||||||
@@ -71,6 +68,9 @@ type ProxyServiceServer struct {
|
|||||||
// Manager for reverse proxy operations
|
// Manager for reverse proxy operations
|
||||||
serviceManager rpservice.Manager
|
serviceManager rpservice.Manager
|
||||||
|
|
||||||
|
// ProxyController for service updates and cluster management
|
||||||
|
proxyController rpservice.ProxyController
|
||||||
|
|
||||||
// Manager for proxy connections
|
// Manager for proxy connections
|
||||||
proxyManager proxy.Manager
|
proxyManager proxy.Manager
|
||||||
|
|
||||||
@@ -173,6 +173,10 @@ func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) {
|
|||||||
s.serviceManager = manager
|
s.serviceManager = manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) {
|
||||||
|
s.proxyController = proxyController
|
||||||
|
}
|
||||||
|
|
||||||
// GetMappingUpdate handles the control stream with proxy clients
|
// GetMappingUpdate handles the control stream with proxy clients
|
||||||
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error {
|
||||||
ctx := stream.Context()
|
ctx := stream.Context()
|
||||||
@@ -205,7 +209,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Store(proxyID, conn)
|
s.connectedProxies.Store(proxyID, conn)
|
||||||
s.addToCluster(conn.address, proxyID)
|
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Register proxy in database
|
// Register proxy in database
|
||||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil {
|
||||||
@@ -224,7 +230,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.connectedProxies.Delete(proxyID)
|
s.connectedProxies.Delete(proxyID)
|
||||||
s.removeFromCluster(conn.address, proxyID)
|
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||||
|
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||||
|
}
|
||||||
|
|
||||||
cancel()
|
cancel()
|
||||||
log.Infof("Proxy %s disconnected", proxyID)
|
log.Infof("Proxy %s disconnected", proxyID)
|
||||||
@@ -446,61 +454,43 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string {
|
|||||||
return urls
|
return urls
|
||||||
}
|
}
|
||||||
|
|
||||||
// addToCluster registers a proxy in a cluster.
|
|
||||||
func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) {
|
|
||||||
if clusterAddr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
|
||||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
|
||||||
log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeFromCluster removes a proxy from a cluster.
|
|
||||||
func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
|
||||||
if clusterAddr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok {
|
|
||||||
proxySet.(*sync.Map).Delete(proxyID)
|
|
||||||
log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
// SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster.
|
||||||
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
||||||
// For create/update operations a unique one-time auth token is generated per
|
// For create/update operations a unique one-time auth token is generated per
|
||||||
// proxy so that every replica can independently authenticate with management.
|
// proxy so that every replica can independently authenticate with management.
|
||||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) {
|
||||||
if clusterAddr == "" {
|
if clusterAddr == "" {
|
||||||
s.SendServiceUpdate(update)
|
s.SendServiceUpdate(update)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
proxySet, ok := s.clusterProxies.Load(clusterAddr)
|
if s.proxyController == nil {
|
||||||
if !ok {
|
log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr)
|
||||||
log.Debugf("No proxies connected for cluster %s", clusterAddr)
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr)
|
||||||
|
if len(proxyIDs) == 0 {
|
||||||
|
log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
log.Debugf("Sending service update to cluster %s", clusterAddr)
|
||||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
for _, proxyID := range proxyIDs {
|
||||||
proxyID := key.(string)
|
|
||||||
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
if connVal, ok := s.connectedProxies.Load(proxyID); ok {
|
||||||
conn := connVal.(*proxyConnection)
|
conn := connVal.(*proxyConnection)
|
||||||
msg := s.perProxyMessage(update, proxyID)
|
msg := s.perProxyMessage(update, proxyID)
|
||||||
if msg == nil {
|
if msg == nil {
|
||||||
return true
|
continue
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case conn.sendChan <- msg:
|
case conn.sendChan <- msg:
|
||||||
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||||
default:
|
default:
|
||||||
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true
|
}
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(update, cluster)
|
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
|
||||||
|
|
||||||
tokens := make([]string, numProxies)
|
tokens := make([]string, numProxies)
|
||||||
for i, ch := range channels {
|
for i, ch := range channels {
|
||||||
@@ -116,7 +116,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|||||||
Domain: "test.example.com",
|
Domain: "test.example.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
s.SendServiceUpdateToCluster(update, cluster)
|
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
msg1 := drainChannel(ch1)
|
||||||
msg2 := drainChannel(ch2)
|
msg2 := drainChannel(ch2)
|
||||||
|
|||||||
Reference in New Issue
Block a user