diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index 535705a37..9a7ac56cb 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -16,6 +16,7 @@ import ( "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/shared/management/status" ) @@ -26,7 +27,7 @@ type ClusterDeriver interface { DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) } -type managerImpl struct { +type Manager struct { store store.Store accountManager account.Manager permissionsManager permissions.Manager @@ -35,8 +36,8 @@ type managerImpl struct { } // NewManager creates a new service manager. -func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { - return &managerImpl{ +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) *Manager { + return &Manager{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, @@ -45,7 +46,7 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa } } -func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { +func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -69,7 +70,7 @@ func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID stri return services, nil } -func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error { +func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error { for _, target := range service.Targets { switch target.TargetType { case reverseproxy.TargetTypePeer: @@ -105,7 +106,7 @@ func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, return nil } -func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { +func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -126,7 +127,7 @@ func (m *managerImpl) GetService(ctx context.Context, accountID, userID, service return service, nil } -func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *Manager) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -150,14 +151,14 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) return service, nil } -func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error { +func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error { if m.clusterDeriver != nil { proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain) if err != nil { @@ -184,7 +185,7 @@ func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID return nil } -func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error { +func (m *Manager) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil { return err @@ -202,7 +203,7 @@ func (m *managerImpl) persistNewService(ctx context.Context, accountID string, s }) } -func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { +func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error { existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain) if err != nil { if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { @@ -218,7 +219,7 @@ func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction stor return nil } -func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update) if err != nil { return nil, status.NewPermissionValidationError(err) @@ -254,7 +255,7 @@ type serviceUpdateInfo struct { serviceEnabledChanged bool } -func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) { +func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) { var updateInfo serviceUpdateInfo err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -292,7 +293,7 @@ func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string return &updateInfo, err } -func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error { +func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error { if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil { return err } @@ -309,7 +310,7 @@ func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store. return nil } -func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) { +func (m *Manager) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) { if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled && existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled && service.Auth.PasswordAuth.Password == "" { @@ -323,25 +324,29 @@ func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reve } } -func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) { +func (m *Manager) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) { + m.proxyGRPCServer.SendServiceUpdateToCluster(update, clusterAddr) +} + +func (m *Manager) preserveServiceMetadata(service, existingService *reverseproxy.Service) { service.Meta = existingService.Meta service.SessionPrivateKey = existingService.SessionPrivateKey service.SessionPublicKey = existingService.SessionPublicKey } -func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { +func (m *Manager) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) { oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() switch { case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster) - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) case !service.Enabled && updateInfo.serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster) case service.Enabled && updateInfo.serviceEnabledChanged: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster) default: - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster) } } @@ -368,7 +373,7 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco return nil } -func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { +func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete) if err != nil { return status.NewPermissionValidationError(err) @@ -397,7 +402,7 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) @@ -406,7 +411,7 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv // SetCertificateIssuedAt sets the certificate issued timestamp to the current time. // Call this when receiving a gRPC notification that the certificate was issued. -func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { +func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { @@ -424,7 +429,7 @@ func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, ser } // SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.) -func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { +func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error { return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID) if err != nil { @@ -441,7 +446,7 @@ func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string }) } -func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error { +func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error { service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return fmt.Errorf("failed to get service: %w", err) @@ -452,14 +457,14 @@ func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID st return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) m.accountManager.UpdateAccountPeers(ctx, accountID) return nil } -func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { +func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error { services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) if err != nil { return fmt.Errorf("failed to get services: %w", err) @@ -470,13 +475,13 @@ func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID if err != nil { return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err) } - m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) + m.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster) } return nil } -func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { +func (m *Manager) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) { services, err := m.store.GetServices(ctx, store.LockingStrengthNone) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) @@ -492,7 +497,7 @@ func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Se return services, nil } -func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { +func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) { service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID) if err != nil { return nil, fmt.Errorf("failed to get service: %w", err) @@ -506,7 +511,7 @@ func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID s return service, nil } -func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { +func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) { services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID) if err != nil { return nil, fmt.Errorf("failed to get services: %w", err) @@ -522,7 +527,7 @@ func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) return services, nil } -func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { +func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) { target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { diff --git a/management/internals/modules/reverseproxy/manager/manager_test.go b/management/internals/modules/reverseproxy/manager/manager_test.go index 266b0066f..eb615bd8d 100644 --- a/management/internals/modules/reverseproxy/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/manager/manager_test.go @@ -20,7 +20,7 @@ func TestInitializeServiceForCreate(t *testing.T) { accountID := "test-account" t.Run("successful initialization without cluster deriver", func(t *testing.T) { - mgr := &managerImpl{ + mgr := &Manager{ clusterDeriver: nil, } @@ -40,7 +40,7 @@ func TestInitializeServiceForCreate(t *testing.T) { }) t.Run("verifies session keys are different", func(t *testing.T) { - mgr := &managerImpl{ + mgr := &Manager{ clusterDeriver: nil, } @@ -136,7 +136,7 @@ func TestCheckDomainAvailable(t *testing.T) { mockStore := store.NewMockStore(ctrl) tt.setupMock(mockStore) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID) if tt.expectedError { @@ -166,7 +166,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { GetServiceByDomain(ctx, accountID, ""). Return(nil, status.Errorf(status.NotFound, "not found")) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "") assert.NoError(t, err) @@ -181,7 +181,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { GetServiceByDomain(ctx, accountID, "test.com"). Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "") assert.Error(t, err) @@ -199,7 +199,7 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) { GetServiceByDomain(ctx, accountID, "nil.com"). Return(nil, nil) - mgr := &managerImpl{} + mgr := &Manager{} err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "") assert.NoError(t, err) @@ -237,7 +237,7 @@ func TestPersistNewService(t *testing.T) { return fn(txMock) }) - mgr := &managerImpl{store: mockStore} + mgr := &Manager{store: mockStore} err := mgr.persistNewService(ctx, accountID, service) assert.NoError(t, err) @@ -265,7 +265,7 @@ func TestPersistNewService(t *testing.T) { return fn(txMock) }) - mgr := &managerImpl{store: mockStore} + mgr := &Manager{store: mockStore} err := mgr.persistNewService(ctx, accountID, service) require.Error(t, err) @@ -275,7 +275,7 @@ func TestPersistNewService(t *testing.T) { }) } func TestPreserveExistingAuthSecrets(t *testing.T) { - mgr := &managerImpl{} + mgr := &Manager{} t.Run("preserve password when empty", func(t *testing.T) { existing := &reverseproxy.Service{ @@ -352,7 +352,7 @@ func TestPreserveExistingAuthSecrets(t *testing.T) { } func TestPreserveServiceMetadata(t *testing.T) { - mgr := &managerImpl{} + mgr := &Manager{} existing := &reverseproxy.Service{ Meta: reverseproxy.ServiceMeta{