refactor service manager code and add tests

This commit is contained in:
pascal
2026-02-13 12:08:01 +01:00
parent c4bfbbaa52
commit 0a884d839e
4 changed files with 3282 additions and 119 deletions

View File

@@ -135,54 +135,11 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
var proxyCluster string
if m.clusterDeriver != nil {
proxyCluster, err = m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return nil, status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
return nil, err
}
service.AccountID = accountID
service.ProxyCluster = proxyCluster
service.InitNewRecord()
err = service.Auth.HashSecrets()
if err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
// Generate session JWT signing keys
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return nil, fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Check for duplicate domain
existingService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
}
if existingService != nil {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
if err != nil {
if err := m.persistNewService(ctx, accountID, service); err != nil {
return nil, err
}
@@ -200,6 +157,67 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
return service, nil
}
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
if m.clusterDeriver != nil {
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
}
service.AccountID = accountID
service.InitNewRecord()
if err := service.Auth.HashSecrets(); err != nil {
return fmt.Errorf("hash secrets: %w", err)
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
}
service.SessionPrivateKey = keyPair.PrivateKey
service.SessionPublicKey = keyPair.PublicKey
return nil
}
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
}
return nil
}
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
if err != nil {
@@ -209,99 +227,122 @@ func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID strin
return nil, status.NewPermissionDeniedError()
}
var oldCluster string
var domainChanged bool
var serviceEnabledChanged bool
err = service.Auth.HashSecrets()
if err != nil {
if err := service.Auth.HashSecrets(); err != nil {
return nil, fmt.Errorf("hash secrets: %w", err)
}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
oldCluster = existingService.ProxyCluster
if existingService.Domain != service.Domain {
domainChanged = true
conflictService, err := transaction.GetServiceByDomain(ctx, accountID, service.Domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("check existing service: %w", err)
}
}
if conflictService != nil && conflictService.ID != service.ID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", service.Domain)
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
}
service.ProxyCluster = newCluster
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
serviceEnabledChanged = existingService.Enabled != service.Enabled
if err = validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err = transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
if err != nil {
return nil, err
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
err = m.replaceHostByLookup(ctx, accountID, service)
if err != nil {
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
}
m.sendServiceUpdateNotifications(service, updateInfo)
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
type serviceUpdateInfo struct {
oldCluster string
domainChanged bool
serviceEnabledChanged bool
}
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
})
return &updateInfo, err
}
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
} else {
service.ProxyCluster = newCluster
}
}
return nil
}
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
}
}
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
service.Meta = existingService.Meta
service.SessionPrivateKey = existingService.SessionPrivateKey
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
switch {
case domainChanged && oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), oldCluster)
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
case !service.Enabled && serviceEnabledChanged:
case !service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
case service.Enabled && serviceEnabledChanged:
case service.Enabled && updateInfo.serviceEnabledChanged:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
default:
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
}
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.