diff --git a/management/internals/modules/reverseproxy/service/interface.go b/management/internals/modules/reverseproxy/service/interface.go index 001a68149..d5219511b 100644 --- a/management/internals/modules/reverseproxy/service/interface.go +++ b/management/internals/modules/reverseproxy/service/interface.go @@ -4,6 +4,8 @@ package service import ( "context" + + "github.com/netbirdio/netbird/shared/management/proto" ) type Manager interface { @@ -21,3 +23,12 @@ type Manager interface { GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) } + +// ProxyController is responsible for managing proxy clusters and routing service updates. +type ProxyController interface { + SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) + GetOIDCValidationConfig() OIDCValidationConfig + RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error + UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error + GetProxiesForCluster(clusterAddr string) []string +} diff --git a/management/internals/modules/reverseproxy/service/interface_mock.go b/management/internals/modules/reverseproxy/service/interface_mock.go index 7e946a821..673e11fa9 100644 --- a/management/internals/modules/reverseproxy/service/interface_mock.go +++ b/management/internals/modules/reverseproxy/service/interface_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + proto "github.com/netbirdio/netbird/shared/management/proto" ) // MockManager is a mock of Manager interface. 
@@ -223,3 +224,94 @@ func (mr *MockManagerMockRecorder) UpdateService(ctx, accountID, userID, service
 	mr.mock.ctrl.T.Helper()
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateService", reflect.TypeOf((*MockManager)(nil).UpdateService), ctx, accountID, userID, service)
 }
+
+// MockProxyController is a mock of ProxyController interface.
+type MockProxyController struct {
+	ctrl     *gomock.Controller
+	recorder *MockProxyControllerMockRecorder
+}
+
+// MockProxyControllerMockRecorder is the mock recorder for MockProxyController.
+type MockProxyControllerMockRecorder struct {
+	mock *MockProxyController
+}
+
+// NewMockProxyController creates a new mock instance.
+func NewMockProxyController(ctrl *gomock.Controller) *MockProxyController {
+	mock := &MockProxyController{ctrl: ctrl}
+	mock.recorder = &MockProxyControllerMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use.
+func (m *MockProxyController) EXPECT() *MockProxyControllerMockRecorder {
+	return m.recorder
+}
+
+// GetOIDCValidationConfig mocks base method.
+func (m *MockProxyController) GetOIDCValidationConfig() OIDCValidationConfig {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
+	ret0, _ := ret[0].(OIDCValidationConfig)
+	return ret0
+}
+
+// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
+func (mr *MockProxyControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockProxyController)(nil).GetOIDCValidationConfig))
+}
+
+// GetProxiesForCluster mocks base method.
+func (m *MockProxyController) GetProxiesForCluster(clusterAddr string) []string {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
+	ret0, _ := ret[0].([]string)
+	return ret0
+}
+
+// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
+func (mr *MockProxyControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockProxyController)(nil).GetProxiesForCluster), clusterAddr)
+}
+
+// RegisterProxyToCluster mocks base method.
+func (m *MockProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
+func (mr *MockProxyControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockProxyController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
+}
+
+// SendServiceUpdateToCluster mocks base method; ctx comes first, matching the ProxyController interface.
+func (m *MockProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
+	m.ctrl.T.Helper()
+	m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
+}
+
+// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
+func (mr *MockProxyControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockProxyController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
+}
+
+// UnregisterProxyFromCluster mocks base method.
+func (m *MockProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
+	m.ctrl.T.Helper()
+	ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
+func (mr *MockProxyControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
+	mr.mock.ctrl.T.Helper()
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockProxyController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
+}
diff --git a/management/internals/modules/reverseproxy/service/manager/controller.go b/management/internals/modules/reverseproxy/service/manager/controller.go
new file mode 100644
index 000000000..5b58af1a3
--- /dev/null
+++ b/management/internals/modules/reverseproxy/service/manager/controller.go
@@ -0,0 +1,74 @@
+package manager
+
+import (
+	"context"
+	"sync"
+
+	log "github.com/sirupsen/logrus"
+
+	rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
+	nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
+	"github.com/netbirdio/netbird/shared/management/proto"
+)
+
+// GRPCProxyController tracks which proxies belong to which cluster and forwards service updates to the gRPC proxy server.
+type GRPCProxyController struct {
+	proxyGRPCServer *nbgrpc.ProxyServiceServer
+	// clusterProxies maps cluster address -> *sync.Map used as a set whose keys are proxy IDs.
+	clusterProxies sync.Map
+}
+
+// NewGRPCProxyController returns a controller backed by proxyGRPCServer with no proxies registered yet.
+func NewGRPCProxyController(proxyGRPCServer *nbgrpc.ProxyServiceServer) *GRPCProxyController {
+	return &GRPCProxyController{
+		proxyGRPCServer: proxyGRPCServer,
+	}
+}
+
+// SendServiceUpdateToCluster forwards update to the gRPC server for clusterAddr; accountID is not used here.
+func (c *GRPCProxyController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
+	c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
+}
+
+// GetOIDCValidationConfig returns the OIDC validation configuration held by the wrapped gRPC server.
+func (c *GRPCProxyController) GetOIDCValidationConfig() rpservice.OIDCValidationConfig {
+	return c.proxyGRPCServer.GetOIDCValidationConfig()
+}
+
+// RegisterProxyToCluster records proxyID under clusterAddr; an empty clusterAddr is a no-op. It always returns nil.
+func (c *GRPCProxyController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
+	if clusterAddr == "" {
+		return nil
+	}
+	proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
+	proxySet.(*sync.Map).Store(proxyID, struct{}{})
+	log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
+	return nil
+}
+
+// UnregisterProxyFromCluster drops proxyID from clusterAddr's set; an empty or unknown clusterAddr is a no-op. It always returns nil.
+func (c *GRPCProxyController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
+	if clusterAddr == "" {
+		return nil
+	}
+	if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
+		proxySet.(*sync.Map).Delete(proxyID)
+		log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
+	}
+	return nil
+}
+
+// GetProxiesForCluster returns the proxy IDs registered for clusterAddr, or nil when none are registered.
+func (c *GRPCProxyController) GetProxiesForCluster(clusterAddr string) []string { + proxySet, ok := c.clusterProxies.Load(clusterAddr) + if !ok { + return nil + } + + var proxies []string + proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { + proxies = append(proxies, key.(string)) + return true + }) + return proxies +} diff --git a/management/internals/modules/reverseproxy/service/manager/manager.go b/management/internals/modules/reverseproxy/service/manager/manager.go index a2a616737..9927f4244 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager.go +++ b/management/internals/modules/reverseproxy/service/manager/manager.go @@ -7,16 +7,14 @@ import ( log "github.com/sirupsen/logrus" - rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" - nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -31,22 +29,22 @@ type Manager struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager - proxyGRPCServer *nbgrpc.ProxyServiceServer + proxyController service.ProxyController clusterDeriver ClusterDeriver } // NewManager creates a new service manager. 
-func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) rpservice.Manager { +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController service.ProxyController, clusterDeriver ClusterDeriver) *Manager { return &Manager{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, - proxyGRPCServer: proxyGRPCServer, + proxyController: proxyController, clusterDeriver: clusterDeriver, } } -func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*rpservice.Service, error) { +func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -70,34 +68,34 @@ func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) return services, nil } -func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, service *rpservice.Service) error { - for _, target := range service.Targets { +func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error { + for _, target := range s.Targets { switch target.TargetType { - case rpservice.TargetTypePeer: + case service.TargetTypePeer: peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) if err != nil { - log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err) + log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err) target.Host = unknownHostPlaceholder continue } target.Host = peer.IP.String() - case rpservice.TargetTypeHost: + case service.TargetTypeHost: 
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) if err != nil { - log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) + log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err) target.Host = unknownHostPlaceholder continue } target.Host = resource.Prefix.Addr().String() - case rpservice.TargetTypeDomain: + case service.TargetTypeDomain: resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId) if err != nil { - log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err) + log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err) target.Host = unknownHostPlaceholder continue } target.Host = resource.Domain - case rpservice.TargetTypeSubnet: + case service.TargetTypeSubnet: // For subnets we do not do any lookups on the resource default: return fmt.Errorf("unknown target type: %s", target.TargetType) @@ -106,7 +104,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, ser return nil } -func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*rpservice.Service, error) { +func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -127,7 +125,7 @@ func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID s return service, nil } -func (m *Manager) CreateService(ctx context.Context, accountID, userID string, service *rpservice.Service) (*rpservice.Service, error) { +func (m *Manager) CreateService(ctx context.Context, accountID, 
userID string, s *service.Service) (*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -136,29 +134,29 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s return nil, status.NewPermissionDeniedError() } - if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil { + if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil { return nil, err } - if err := m.persistNewService(ctx, accountID, service); err != nil { + if err := m.persistNewService(ctx, accountID, s); err != nil { return nil, err } - m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta()) + m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta()) - err = m.replaceHostByLookup(ctx, accountID, service) + err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { - return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) - return service, nil + return s, nil } -func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *rpservice.Service) error { +func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error { if m.clusterDeriver != nil { proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, 
accountID, service.Domain) if err != nil { @@ -185,7 +183,7 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri return nil } -func (m *Manager) persistNewService(ctx context.Context, accountID string, service *rpservice.Service) error { +func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { return err @@ -219,7 +217,7 @@ func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.St return nil } -func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *rpservice.Service) (*rpservice.Service, error) { +func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -243,7 +241,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.sendServiceUpdateNotifications(accountID, service, updateInfo) + m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo) m.accountManager.UpdateAccountPeers(ctx, accountID) return service, nil @@ -255,7 +253,7 @@ type serviceUpdateInfo struct { serviceEnabledChanged bool } -func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *rpservice.Service) (*serviceUpdateInfo, error) { +func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) { var updateInfo serviceUpdateInfo err := m.store.ExecuteInTransaction(ctx, func(transaction 
store.Store) error { @@ -293,7 +291,7 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se return &updateInfo, err } -func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *rpservice.Service) error { +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { return err } @@ -310,7 +308,7 @@ func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Stor return nil } -func (m *Manager) preserveExistingAuthSecrets(service, existingService *rpservice.Service) { +func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) { if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && service.Auth.PasswordAuth.Password == "" { @@ -324,44 +322,40 @@ func (m *Manager) preserveExistingAuthSecrets(service, existingService *rpservic } } -func (m *Manager) SendServiceUpdateToCluster(accountID string, update *proto.ProxyMapping, clusterAddr string) { - m.proxyGRPCServer.SendServiceUpdateToCluster(update, clusterAddr) -} - -func (m *Manager) preserveServiceMetadata(service, existingService *rpservice.Service) { +func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) { service.Meta = existingService.Meta service.SessionPrivateKey = existingService.SessionPrivateKey service.SessionPublicKey = existingService.SessionPublicKey } -func (m *Manager) sendServiceUpdateNotifications(accountID string, service *rpservice.Service, updateInfo *serviceUpdateInfo) { - oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() +func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo 
*serviceUpdateInfo) { + oidcCfg := m.proxyController.GetOIDCValidationConfig() switch { - case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", oidcCfg), updateInfo.oldCluster) - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", oidcCfg), service.ProxyCluster) - case !service.Enabled && updateInfo.serviceEnabledChanged: - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", oidcCfg), service.ProxyCluster) - case service.Enabled && updateInfo.serviceEnabledChanged: - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Create, "", oidcCfg), service.ProxyCluster) + case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) + case !s.Enabled && updateInfo.serviceEnabledChanged: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster) + case s.Enabled && updateInfo.serviceEnabledChanged: + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster) default: - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", oidcCfg), service.ProxyCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster) } } // validateTargetReferences checks that all target IDs reference existing peers or resources in the account. 
-func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*rpservice.Target) error { +func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error { for _, target := range targets { switch target.TargetType { - case rpservice.TargetTypePeer: + case service.TargetTypePeer: if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId) } return fmt.Errorf("look up peer target %q: %w", target.TargetId, err) } - case rpservice.TargetTypeHost, rpservice.TargetTypeSubnet, rpservice.TargetTypeDomain: + case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain: if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId) @@ -382,10 +376,10 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI return status.NewPermissionDeniedError() } - var service *rpservice.Service + var s *service.Service err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var err error - service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) + s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { return err } @@ -400,9 +394,9 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI return err } - m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) + 
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta()) - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -429,7 +423,7 @@ func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, service } // SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) -func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status rpservice.Status) error { +func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { @@ -447,17 +441,17 @@ func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, st } func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error { - service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) + s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return fmt.Errorf("failed to get service: %w", err) } - err = m.replaceHostByLookup(ctx, accountID, service) + err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { - return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + 
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -470,18 +464,18 @@ func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID str return fmt.Errorf("failed to get services: %w", err) } - for _, service := range services { - err = m.replaceHostByLookup(ctx, accountID, service) + for _, s := range services { + err = m.replaceHostByLookup(ctx, accountID, s) if err != nil { - return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) + return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err) } - m.SendServiceUpdateToCluster(accountID, service.ToProtoMapping(rpservice.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster) } return nil } -func (m *Manager) GetGlobalServices(ctx context.Context) ([]*rpservice.Service, error) { +func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) { services, err := m.store.GetServices(ctx, store.LockingStrengthNone) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) @@ -497,7 +491,7 @@ func (m *Manager) GetGlobalServices(ctx context.Context) ([]*rpservice.Service, return services, nil } -func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*rpservice.Service, error) { +func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) { service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return nil, fmt.Errorf("failed to get service: %w", err) @@ -511,7 +505,7 @@ func (m *Manager) GetServiceByID(ctx context.Context, accountID, 
serviceID strin return service, nil } -func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) { +func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) { services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 98f3ad3eb..74941938c 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -166,6 +166,7 @@ func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager()) s.AfterInit(func(s *BaseServer) { proxyService.SetProxyManager(s.ServiceManager()) + proxyService.SetProxyController(s.ServiceProxyController()) }) return proxyService }) diff --git a/management/internals/server/controllers.go b/management/internals/server/controllers.go index 4ea86900a..d3263a077 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -6,6 +6,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" + nbreverseproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" @@ -106,6 +108,12 @@ func (s *BaseServer) NetworkMapController() network_map.Controller { }) } +func (s *BaseServer) ServiceProxyController() service.ProxyController { + return Create(s, func() service.ProxyController { + return 
nbreverseproxy.NewGRPCProxyController(s.ReverseProxyGRPCServer()) + }) +} + func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer { return Create(s, func() *server.AccountRequestBuffer { return server.NewAccountRequestBuffer(context.Background(), s.Store()) diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index ba8820596..d250721ba 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -194,7 +194,7 @@ func (s *BaseServer) RecordsManager() records.Manager { func (s *BaseServer) ServiceManager() service.Manager { return Create(s, func() service.Manager { - return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ReverseProxyGRPCServer(), s.ReverseProxyDomainManager()) + return nbreverseproxy.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.ServiceProxyController(), s.ReverseProxyDomainManager()) }) } diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index 462198124..7999407db 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -7,9 +7,11 @@ import ( "crypto/subtle" "encoding/base64" "encoding/hex" + "encoding/json" "fmt" "time" + "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" @@ -26,7 +28,7 @@ type tokenMetadata struct { // OneTimeTokenStore manages single-use authentication tokens for proxy-to-management RPC. // Supports both in-memory and Redis storage via NB_IDP_CACHE_REDIS_ADDRESS env var. type OneTimeTokenStore struct { - store store.StoreInterface + cache *cache.Cache[string] ctx context.Context } @@ -38,7 +40,7 @@ func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time. 
} return &OneTimeTokenStore{ - store: cacheStore, + cache: cache.New[string](cacheStore), ctx: ctx, }, nil } @@ -64,7 +66,12 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time. CreatedAt: time.Now(), } - if err := s.store.Set(s.ctx, hashedToken, metadata, store.WithExpiration(ttl)); err != nil { + metadataJSON, err := json.Marshal(metadata) + if err != nil { + return "", fmt.Errorf("failed to serialize token metadata: %w", err) + } + + if err := s.cache.Set(s.ctx, hashedToken, string(metadataJSON), store.WithExpiration(ttl)); err != nil { return "", fmt.Errorf("failed to store token: %w", err) } @@ -87,20 +94,19 @@ func (s *OneTimeTokenStore) GenerateToken(accountID, serviceID string, ttl time. func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID string) error { hashedToken := hashToken(token) - value, err := s.store.Get(s.ctx, hashedToken) + metadataJSON, err := s.cache.Get(s.ctx, hashedToken) if err != nil { log.Warnf("Token validation failed: token not found (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("invalid token") } - metadata, ok := value.(*tokenMetadata) - if !ok { - log.Warnf("Token validation failed: invalid metadata type (proxy: %s, account: %s)", serviceID, accountID) + metadata := &tokenMetadata{} + if err := json.Unmarshal([]byte(metadataJSON), metadata); err != nil { + log.Warnf("Token validation failed: failed to unmarshal metadata (proxy: %s, account: %s): %v", serviceID, accountID, err) return fmt.Errorf("invalid token metadata") } if time.Now().After(metadata.ExpiresAt) { - s.store.Delete(s.ctx, hashedToken) log.Warnf("Token validation failed: token expired (proxy: %s, account: %s)", serviceID, accountID) return fmt.Errorf("token expired") } @@ -115,7 +121,7 @@ func (s *OneTimeTokenStore) ValidateAndConsume(token, accountID, serviceID strin return fmt.Errorf("service ID mismatch") } - if err := s.store.Delete(s.ctx, hashedToken); err != nil { + if err := 
s.cache.Delete(s.ctx, hashedToken); err != nil { log.Warnf("Token deletion warning (proxy: %s, account: %s): %v", serviceID, accountID, err) } diff --git a/management/internals/shared/grpc/proxy.go b/management/internals/shared/grpc/proxy.go index f002f117e..bdaeda866 100644 --- a/management/internals/shared/grpc/proxy.go +++ b/management/internals/shared/grpc/proxy.go @@ -59,9 +59,6 @@ type ProxyServiceServer struct { // Map of connected proxies: proxy_id -> proxy connection connectedProxies sync.Map - // Map of cluster address -> set of proxy IDs - clusterProxies sync.Map - // Channel for broadcasting reverse proxy updates to all proxies updatesChan chan *proto.ProxyMapping @@ -71,6 +68,9 @@ type ProxyServiceServer struct { // Manager for reverse proxy operations serviceManager rpservice.Manager + // ProxyController for service updates and cluster management + proxyController rpservice.ProxyController + // Manager for proxy connections proxyManager proxy.Manager @@ -173,6 +173,10 @@ func (s *ProxyServiceServer) SetProxyManager(manager rpservice.Manager) { s.serviceManager = manager } +func (s *ProxyServiceServer) SetProxyController(proxyController rpservice.ProxyController) { + s.proxyController = proxyController +} + // GetMappingUpdate handles the control stream with proxy clients func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest, stream proto.ProxyService_GetMappingUpdateServer) error { ctx := stream.Context() @@ -205,7 +209,9 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } s.connectedProxies.Store(proxyID, conn) - s.addToCluster(conn.address, proxyID) + if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil { + log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err) + } // Register proxy in database if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo); err != nil { @@ -224,7 +230,9 @@ func (s 
*ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest } s.connectedProxies.Delete(proxyID) - s.removeFromCluster(conn.address, proxyID) + if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil { + log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err) + } cancel() log.Infof("Proxy %s disconnected", proxyID) @@ -446,61 +454,43 @@ func (s *ProxyServiceServer) GetConnectedProxyURLs() []string { return urls } -// addToCluster registers a proxy in a cluster. -func (s *ProxyServiceServer) addToCluster(clusterAddr, proxyID string) { - if clusterAddr == "" { - return - } - proxySet, _ := s.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{}) - proxySet.(*sync.Map).Store(proxyID, struct{}{}) - log.Debugf("Added proxy %s to cluster %s", proxyID, clusterAddr) -} - -// removeFromCluster removes a proxy from a cluster. -func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) { - if clusterAddr == "" { - return - } - if proxySet, ok := s.clusterProxies.Load(clusterAddr); ok { - proxySet.(*sync.Map).Delete(proxyID) - log.Debugf("Removed proxy %s from cluster %s", proxyID, clusterAddr) - } -} - // SendServiceUpdateToCluster sends a service update to all proxy servers in a specific cluster. // If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility). // For create/update operations a unique one-time auth token is generated per // proxy so that every replica can independently authenticate with management. 
-func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { +func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, update *proto.ProxyMapping, clusterAddr string) { if clusterAddr == "" { s.SendServiceUpdate(update) return } - proxySet, ok := s.clusterProxies.Load(clusterAddr) - if !ok { - log.Debugf("No proxies connected for cluster %s", clusterAddr) + if s.proxyController == nil { + log.WithContext(ctx).Debugf("ProxyController not set, cannot send to cluster %s", clusterAddr) + return + } + + proxyIDs := s.proxyController.GetProxiesForCluster(clusterAddr) + if len(proxyIDs) == 0 { + log.WithContext(ctx).Debugf("No proxies connected for cluster %s", clusterAddr) return } - log.Debugf("Sending service update to cluster %s", clusterAddr) + log.WithContext(ctx).Debugf("Sending service update to cluster %s", clusterAddr) - proxySet.(*sync.Map).Range(func(key, _ interface{}) bool { - proxyID := key.(string) + for _, proxyID := range proxyIDs { if connVal, ok := s.connectedProxies.Load(proxyID); ok { conn := connVal.(*proxyConnection) msg := s.perProxyMessage(update, proxyID) if msg == nil { - return true + continue } select { case conn.sendChan <- msg: - log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) + log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr) default: - log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) + log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr) } } - return true - }) + } } // perProxyMessage returns a copy of update with a fresh one-time token for diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index 8308abbf3..f8863d78f 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -69,7 +69,7 @@ func
TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { }, } - s.SendServiceUpdateToCluster(update, cluster) + s.SendServiceUpdateToCluster(context.Background(), update, cluster) tokens := make([]string, numProxies) for i, ch := range channels { @@ -116,7 +116,7 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { Domain: "test.example.com", } - s.SendServiceUpdateToCluster(update, cluster) + s.SendServiceUpdateToCluster(context.Background(), update, cluster) msg1 := drainChannel(ch1) msg2 := drainChannel(ch2)