mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Merge branch 'refs/heads/main' into refactor/permissions-manager
# Conflicts: # management/internals/modules/reverseproxy/service/manager/api.go # management/server/http/handler.go
This commit is contained in:
@@ -9,4 +9,5 @@ type Manager interface {
|
||||
CreateDomain(ctx context.Context, accountID, userID, domainName, targetCluster string) (*Domain, error)
|
||||
DeleteDomain(ctx context.Context, accountID, userID, domainID string) error
|
||||
ValidateDomain(ctx context.Context, accountID, userID, domainID string)
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
@@ -27,21 +27,21 @@ type store interface {
|
||||
DeleteCustomDomain(ctx context.Context, accountID string, domainID string) error
|
||||
}
|
||||
|
||||
type proxyURLProvider interface {
|
||||
GetConnectedProxyURLs() []string
|
||||
type proxyManager interface {
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store
|
||||
validator domain.Validator
|
||||
proxyURLProvider proxyURLProvider
|
||||
proxyManager proxyManager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func NewManager(store store, proxyURLProvider proxyURLProvider, permissionsManager permissions.Manager) Manager {
|
||||
func NewManager(store store, proxyMgr proxyManager, permissionsManager permissions.Manager) Manager {
|
||||
return Manager{
|
||||
store: store,
|
||||
proxyURLProvider: proxyURLProvider,
|
||||
store: store,
|
||||
proxyManager: proxyMgr,
|
||||
validator: domain.Validator{
|
||||
Resolver: net.DefaultResolver,
|
||||
},
|
||||
@@ -67,8 +67,12 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
|
||||
|
||||
// Add connected proxy clusters as free domains.
|
||||
// The cluster address itself is the free domain base (e.g., "eu.proxy.netbird.io").
|
||||
allowList := m.proxyURLAllowList()
|
||||
log.WithFields(log.Fields{
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"accountID": accountID,
|
||||
"proxyAllowList": allowList,
|
||||
}).Debug("getting domains with proxy allow list")
|
||||
@@ -107,7 +111,10 @@ func (m Manager) CreateDomain(ctx context.Context, accountID, userID, domainName
|
||||
}
|
||||
|
||||
// Verify the target cluster is in the available clusters
|
||||
allowList := m.proxyURLAllowList()
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
clusterValid := false
|
||||
for _, cluster := range allowList {
|
||||
if cluster == targetCluster {
|
||||
@@ -221,21 +228,26 @@ func (m Manager) ValidateDomain(ctx context.Context, accountID, userID, domainID
|
||||
}
|
||||
}
|
||||
|
||||
// proxyURLAllowList retrieves a list of currently connected proxies and
|
||||
// their URLs
|
||||
func (m Manager) proxyURLAllowList() []string {
|
||||
var reverseProxyAddresses []string
|
||||
if m.proxyURLProvider != nil {
|
||||
reverseProxyAddresses = m.proxyURLProvider.GetConnectedProxyURLs()
|
||||
// GetClusterDomains returns a list of proxy cluster domains.
|
||||
func (m Manager) GetClusterDomains() []string {
|
||||
if m.proxyManager == nil {
|
||||
return nil
|
||||
}
|
||||
return reverseProxyAddresses
|
||||
addresses, err := m.proxyManager.GetActiveClusterAddresses(context.Background())
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
// DeriveClusterFromDomain determines the proxy cluster for a given domain.
|
||||
// For free domains (those ending with a known cluster suffix), the cluster is extracted from the domain.
|
||||
// For custom domains, the cluster is determined by checking the registered custom domain's target cluster.
|
||||
func (m Manager) DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error) {
|
||||
allowList := m.proxyURLAllowList()
|
||||
allowList, err := m.proxyManager.GetActiveClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active proxy cluster addresses: %w", err)
|
||||
}
|
||||
if len(allowList) == 0 {
|
||||
return "", fmt.Errorf("no proxy clusters available")
|
||||
}
|
||||
|
||||
@@ -1,609 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const unknownHostPlaceholder = "unknown"
|
||||
|
||||
// ClusterDeriver derives the proxy cluster from a domain.
|
||||
type ClusterDeriver interface {
|
||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
clusterDeriver ClusterDeriver
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) replaceHostByLookup(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
for _, target := range service.Targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = peer.IP.String()
|
||||
case reverseproxy.TargetTypeHost:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Prefix.Addr().String()
|
||||
case reverseproxy.TargetTypeDomain:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, service.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Domain
|
||||
case reverseproxy.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetService(ctx context.Context, accountID, userID, serviceID string) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) CreateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, service); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, service); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceCreated, service.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) initializeServiceForCreate(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
}
|
||||
|
||||
service.AccountID = accountID
|
||||
service.InitNewRecord()
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistNewService(ctx context.Context, accountID string, service *reverseproxy.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) UpdateService(ctx context.Context, accountID, userID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
type serviceUpdateInfo struct {
|
||||
oldCluster string
|
||||
domainChanged bool
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *managerImpl) persistServiceUpdate(ctx context.Context, accountID string, service *reverseproxy.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *managerImpl) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *reverseproxy.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveExistingAuthSecrets(service, existingService *reverseproxy.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) preserveServiceMetadata(service, existingService *reverseproxy.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "")
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
default:
|
||||
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
|
||||
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings,
|
||||
}
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
||||
for _, target := range targets {
|
||||
switch target.TargetType {
|
||||
case reverseproxy.TargetTypePeer:
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||
}
|
||||
case reverseproxy.TargetTypeHost, reverseproxy.TargetTypeSubnet, reverseproxy.TargetTypeDomain:
|
||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var service *reverseproxy.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
service, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*reverseproxy.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterMappings := make(map[string][]*proto.ProxyMapping)
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
for _, service := range services {
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
|
||||
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
|
||||
}
|
||||
|
||||
for cluster, mappings := range clusterMappings {
|
||||
m.sendMappingsToCluster(mappings, cluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||
func (m *managerImpl) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
service.Meta.CertificateIssuedAt = time.Now()
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||
func (m *managerImpl) SetStatus(ctx context.Context, accountID, serviceID string, status reverseproxy.ProxyStatus) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
service.Meta.Status = string(status)
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
clusterMappings := make(map[string][]*proto.ProxyMapping)
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg)
|
||||
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
|
||||
}
|
||||
|
||||
for cluster, mappings := range clusterMappings {
|
||||
m.sendMappingsToCluster(mappings, cluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetGlobalServices(ctx context.Context) ([]*reverseproxy.Service, error) {
|
||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceByID(ctx context.Context, accountID, serviceID string) (*reverseproxy.Service, error) {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return target.ServiceID, nil
|
||||
}
|
||||
@@ -1,375 +0,0 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
func TestInitializeServiceForCreate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful initialization without cluster deriver", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service := &reverseproxy.Service{
|
||||
Domain: "example.com",
|
||||
Auth: reverseproxy.AuthConfig{},
|
||||
}
|
||||
|
||||
err := mgr.initializeServiceForCreate(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, accountID, service.AccountID)
|
||||
assert.Empty(t, service.ProxyCluster, "proxy cluster should be empty when no deriver")
|
||||
assert.NotEmpty(t, service.ID, "service ID should be initialized")
|
||||
assert.NotEmpty(t, service.SessionPrivateKey, "session private key should be generated")
|
||||
assert.NotEmpty(t, service.SessionPublicKey, "session public key should be generated")
|
||||
})
|
||||
|
||||
t.Run("verifies session keys are different", func(t *testing.T) {
|
||||
mgr := &managerImpl{
|
||||
clusterDeriver: nil,
|
||||
}
|
||||
|
||||
service1 := &reverseproxy.Service{Domain: "test1.com", Auth: reverseproxy.AuthConfig{}}
|
||||
service2 := &reverseproxy.Service{Domain: "test2.com", Auth: reverseproxy.AuthConfig{}}
|
||||
|
||||
err1 := mgr.initializeServiceForCreate(ctx, accountID, service1)
|
||||
err2 := mgr.initializeServiceForCreate(ctx, accountID, service2)
|
||||
|
||||
assert.NoError(t, err1)
|
||||
assert.NoError(t, err2)
|
||||
assert.NotEqual(t, service1.SessionPrivateKey, service2.SessionPrivateKey, "private keys should be unique")
|
||||
assert.NotEqual(t, service1.SessionPublicKey, service2.SessionPublicKey, "public keys should be unique")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
domain string
|
||||
excludeServiceID string
|
||||
setupMock func(*store.MockStore)
|
||||
expectedError bool
|
||||
errorType status.Type
|
||||
}{
|
||||
{
|
||||
name: "domain available - not found",
|
||||
domain: "available.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "available.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain already exists",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "existing-id", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "domain exists but excluded (same ID)",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-123",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "domain exists with different ID",
|
||||
domain: "exists.com",
|
||||
excludeServiceID: "service-456",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "exists.com").
|
||||
Return(&reverseproxy.Service{ID: "service-123", Domain: "exists.com"}, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
errorType: status.AlreadyExists,
|
||||
},
|
||||
{
|
||||
name: "store error (non-NotFound)",
|
||||
domain: "error.com",
|
||||
excludeServiceID: "",
|
||||
setupMock: func(ms *store.MockStore) {
|
||||
ms.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "error.com").
|
||||
Return(nil, errors.New("database error"))
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
tt.setupMock(mockStore)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
|
||||
|
||||
if tt.expectedError {
|
||||
require.Error(t, err)
|
||||
if tt.errorType != 0 {
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok, "error should be a status error")
|
||||
assert.Equal(t, tt.errorType, sErr.Type())
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("empty domain", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("empty exclude ID with existing service", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "test.com").
|
||||
Return(&reverseproxy.Service{ID: "some-id", Domain: "test.com"}, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
|
||||
|
||||
assert.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
|
||||
t.Run("nil existing service with nil error", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
mockStore.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "nil.com").
|
||||
Return(nil, nil)
|
||||
|
||||
mgr := &managerImpl{}
|
||||
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersistNewService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
accountID := "test-account"
|
||||
|
||||
t.Run("successful service creation with no targets", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "new.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
// Mock ExecuteInTransaction to execute the function immediately
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
// Create another mock for the transaction
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "new.com").
|
||||
Return(nil, status.Errorf(status.NotFound, "not found"))
|
||||
txMock.EXPECT().
|
||||
CreateService(ctx, service).
|
||||
Return(nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("domain already exists", func(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
service := &reverseproxy.Service{
|
||||
ID: "service-123",
|
||||
Domain: "existing.com",
|
||||
Targets: []*reverseproxy.Target{},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
ExecuteInTransaction(ctx, gomock.Any()).
|
||||
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
|
||||
txMock := store.NewMockStore(ctrl)
|
||||
txMock.EXPECT().
|
||||
GetServiceByDomain(ctx, accountID, "existing.com").
|
||||
Return(&reverseproxy.Service{ID: "other-id", Domain: "existing.com"}, nil)
|
||||
|
||||
return fn(txMock)
|
||||
})
|
||||
|
||||
mgr := &managerImpl{store: mockStore}
|
||||
err := mgr.persistNewService(ctx, accountID, service)
|
||||
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, status.AlreadyExists, sErr.Type())
|
||||
})
|
||||
}
|
||||
func TestPreserveExistingAuthSecrets(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
t.Run("preserve password when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashed-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
|
||||
t.Run("preserve pin when empty", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashed-pin",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PinAuth: &reverseproxy.PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Auth.PinAuth, updated.Auth.PinAuth)
|
||||
})
|
||||
|
||||
t.Run("do not preserve when password is provided", func(t *testing.T) {
|
||||
existing := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "old-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Auth: reverseproxy.AuthConfig{
|
||||
PasswordAuth: &reverseproxy.PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "new-password",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr.preserveExistingAuthSecrets(updated, existing)
|
||||
|
||||
assert.Equal(t, "new-password", updated.Auth.PasswordAuth.Password)
|
||||
assert.NotEqual(t, existing.Auth.PasswordAuth, updated.Auth.PasswordAuth)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPreserveServiceMetadata(t *testing.T) {
|
||||
mgr := &managerImpl{}
|
||||
|
||||
existing := &reverseproxy.Service{
|
||||
Meta: reverseproxy.ServiceMeta{
|
||||
CertificateIssuedAt: time.Now(),
|
||||
Status: "active",
|
||||
},
|
||||
SessionPrivateKey: "private-key",
|
||||
SessionPublicKey: "public-key",
|
||||
}
|
||||
|
||||
updated := &reverseproxy.Service{
|
||||
Domain: "updated.com",
|
||||
}
|
||||
|
||||
mgr.preserveServiceMetadata(updated, existing)
|
||||
|
||||
assert.Equal(t, existing.Meta, updated.Meta)
|
||||
assert.Equal(t, existing.SessionPrivateKey, updated.SessionPrivateKey)
|
||||
assert.Equal(t, existing.SessionPublicKey, updated.SessionPublicKey)
|
||||
}
|
||||
36
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
36
management/internals/modules/reverseproxy/proxy/manager.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package proxy
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package proxy -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// OIDCValidationConfig contains the OIDC configuration needed for token validation.
|
||||
type OIDCValidationConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
// Controller is responsible for managing proxy clusters and routing service updates.
|
||||
type Controller interface {
|
||||
SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string)
|
||||
GetOIDCValidationConfig() OIDCValidationConfig
|
||||
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
|
||||
GetProxiesForCluster(clusterAddr string) []string
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// GRPCController is a concrete implementation that manages proxy clusters and sends updates directly via gRPC.
|
||||
type GRPCController struct {
|
||||
proxyGRPCServer *nbgrpc.ProxyServiceServer
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
// NewGRPCController creates a new GRPCController.
|
||||
func NewGRPCController(proxyGRPCServer *nbgrpc.ProxyServiceServer, meter metric.Meter) (*GRPCController, error) {
|
||||
m, err := newMetrics(meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GRPCController{
|
||||
proxyGRPCServer: proxyGRPCServer,
|
||||
metrics: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster sends a service update to a specific proxy cluster.
|
||||
func (c *GRPCController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
c.proxyGRPCServer.SendServiceUpdateToCluster(ctx, update, clusterAddr)
|
||||
c.metrics.IncrementServiceUpdateSendCount(clusterAddr)
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig returns the OIDC validation configuration from the gRPC server.
|
||||
func (c *GRPCController) GetOIDCValidationConfig() proxy.OIDCValidationConfig {
|
||||
return c.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster registers a proxy to a specific cluster for routing.
|
||||
func (c *GRPCController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
proxySet, _ := c.clusterProxies.LoadOrStore(clusterAddr, &sync.Map{})
|
||||
proxySet.(*sync.Map).Store(proxyID, struct{}{})
|
||||
log.WithContext(ctx).Debugf("Registered proxy %s to cluster %s", proxyID, clusterAddr)
|
||||
|
||||
c.metrics.IncrementProxyConnectionCount(clusterAddr)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster removes a proxy from a cluster.
|
||||
func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
if clusterAddr == "" {
|
||||
return nil
|
||||
}
|
||||
if proxySet, ok := c.clusterProxies.Load(clusterAddr); ok {
|
||||
proxySet.(*sync.Map).Delete(proxyID)
|
||||
log.WithContext(ctx).Debugf("Unregistered proxy %s from cluster %s", proxyID, clusterAddr)
|
||||
|
||||
c.metrics.DecrementProxyConnectionCount(clusterAddr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
|
||||
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
proxySet, ok := c.clusterProxies.Load(clusterAddr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
var proxies []string
|
||||
proxySet.(*sync.Map).Range(func(key, _ interface{}) bool {
|
||||
proxies = append(proxies, key.(string))
|
||||
return true
|
||||
})
|
||||
return proxies
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
)
|
||||
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
}
|
||||
|
||||
// Manager handles all proxy operations
|
||||
type Manager struct {
|
||||
store store
|
||||
metrics *metrics
|
||||
}
|
||||
|
||||
// NewManager creates a new proxy Manager
|
||||
func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
m, err := newMetrics(meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
store: store,
|
||||
metrics: m,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Connect registers a new proxy connection in the database
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
return err
|
||||
}
|
||||
m.metrics.IncrementProxyHeartbeatCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses returns all unique cluster addresses for active proxies
|
||||
func (m Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
addresses, err := m.store.GetActiveProxyClusterAddresses(ctx)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get active proxy cluster addresses: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
return addresses, nil
|
||||
}
|
||||
|
||||
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
|
||||
func (m Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to cleanup stale proxies: %v", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
type metrics struct {
|
||||
proxyConnectionCount metric.Int64UpDownCounter
|
||||
serviceUpdateSendCount metric.Int64Counter
|
||||
proxyHeartbeatCount metric.Int64Counter
|
||||
}
|
||||
|
||||
func newMetrics(meter metric.Meter) (*metrics, error) {
|
||||
proxyConnectionCount, err := meter.Int64UpDownCounter(
|
||||
"management_proxy_connection_count",
|
||||
metric.WithDescription("Number of active proxy connections"),
|
||||
metric.WithUnit("{connection}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceUpdateSendCount, err := meter.Int64Counter(
|
||||
"management_proxy_service_update_send_count",
|
||||
metric.WithDescription("Total number of service updates sent to proxies"),
|
||||
metric.WithUnit("{update}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
proxyHeartbeatCount, err := meter.Int64Counter(
|
||||
"management_proxy_heartbeat_count",
|
||||
metric.WithDescription("Total number of proxy heartbeats received"),
|
||||
metric.WithUnit("{heartbeat}"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &metrics{
|
||||
proxyConnectionCount: proxyConnectionCount,
|
||||
serviceUpdateSendCount: serviceUpdateSendCount,
|
||||
proxyHeartbeatCount: proxyHeartbeatCount,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementProxyConnectionCount(clusterAddr string) {
|
||||
m.proxyConnectionCount.Add(context.Background(), 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) DecrementProxyConnectionCount(clusterAddr string) {
|
||||
m.proxyConnectionCount.Add(context.Background(), -1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementServiceUpdateSendCount(clusterAddr string) {
|
||||
m.serviceUpdateSendCount.Add(context.Background(), 1,
|
||||
metric.WithAttributes(
|
||||
attribute.String("cluster", clusterAddr),
|
||||
))
|
||||
}
|
||||
|
||||
func (m *metrics) IncrementProxyHeartbeatCount() {
|
||||
m.proxyHeartbeatCount.Add(context.Background(), 1)
|
||||
}
|
||||
199
management/internals/modules/reverseproxy/proxy/manager_mock.go
Normal file
199
management/internals/modules/reverseproxy/proxy/manager_mock.go
Normal file
@@ -0,0 +1,199 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./manager.go
|
||||
|
||||
// Package proxy is a generated GoMock package.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
proto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
type MockManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||
type MockManagerMockRecorder struct {
|
||||
mock *MockManager
|
||||
}
|
||||
|
||||
// NewMockManager creates a new mock instance.
|
||||
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||
mock := &MockManager{ctrl: ctrl}
|
||||
mock.recorder = &MockManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CleanupStale mocks base method.
|
||||
func (m *MockManager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CleanupStale", ctx, inactivityDuration)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CleanupStale indicates an expected call of CleanupStale.
|
||||
func (mr *MockManagerMockRecorder) CleanupStale(ctx, inactivityDuration interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStale", reflect.TypeOf((*MockManager)(nil).CleanupStale), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disconnect indicates an expected call of Disconnect.
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses mocks base method.
|
||||
func (m *MockManager) GetActiveClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetActiveClusterAddresses", ctx)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses indicates an expected call of GetActiveClusterAddresses.
|
||||
func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Heartbeat indicates an expected call of Heartbeat.
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
type MockController struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockControllerMockRecorder
|
||||
}
|
||||
|
||||
// MockControllerMockRecorder is the mock recorder for MockController.
|
||||
type MockControllerMockRecorder struct {
|
||||
mock *MockController
|
||||
}
|
||||
|
||||
// NewMockController creates a new mock instance.
|
||||
func NewMockController(ctrl *gomock.Controller) *MockController {
|
||||
mock := &MockController{ctrl: ctrl}
|
||||
mock.recorder = &MockControllerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig mocks base method.
|
||||
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOIDCValidationConfig")
|
||||
ret0, _ := ret[0].(OIDCValidationConfig)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetOIDCValidationConfig indicates an expected call of GetOIDCValidationConfig.
|
||||
func (mr *MockControllerMockRecorder) GetOIDCValidationConfig() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOIDCValidationConfig", reflect.TypeOf((*MockController)(nil).GetOIDCValidationConfig))
|
||||
}
|
||||
|
||||
// GetProxiesForCluster mocks base method.
|
||||
func (m *MockController) GetProxiesForCluster(clusterAddr string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProxiesForCluster", clusterAddr)
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetProxiesForCluster indicates an expected call of GetProxiesForCluster.
|
||||
func (mr *MockControllerMockRecorder) GetProxiesForCluster(clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProxiesForCluster", reflect.TypeOf((*MockController)(nil).GetProxiesForCluster), clusterAddr)
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster mocks base method.
|
||||
func (m *MockController) RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RegisterProxyToCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RegisterProxyToCluster indicates an expected call of RegisterProxyToCluster.
|
||||
func (mr *MockControllerMockRecorder) RegisterProxyToCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterProxyToCluster", reflect.TypeOf((*MockController)(nil).RegisterProxyToCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster mocks base method.
|
||||
func (m *MockController) SendServiceUpdateToCluster(ctx context.Context, accountID string, update *proto.ProxyMapping, clusterAddr string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SendServiceUpdateToCluster", ctx, accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// SendServiceUpdateToCluster indicates an expected call of SendServiceUpdateToCluster.
|
||||
func (mr *MockControllerMockRecorder) SendServiceUpdateToCluster(ctx, accountID, update, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendServiceUpdateToCluster", reflect.TypeOf((*MockController)(nil).SendServiceUpdateToCluster), ctx, accountID, update, clusterAddr)
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster mocks base method.
|
||||
func (m *MockController) UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnregisterProxyFromCluster", ctx, clusterAddr, proxyID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UnregisterProxyFromCluster indicates an expected call of UnregisterProxyFromCluster.
|
||||
func (mr *MockControllerMockRecorder) UnregisterProxyFromCluster(ctx, clusterAddr, proxyID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnregisterProxyFromCluster", reflect.TypeOf((*MockController)(nil).UnregisterProxyFromCluster), ctx, clusterAddr, proxyID)
|
||||
}
|
||||
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
20
management/internals/modules/reverseproxy/proxy/proxy.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package proxy
|
||||
|
||||
import "time"
|
||||
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
func (Proxy) TableName() string {
|
||||
return "proxies"
|
||||
}
|
||||
@@ -1,463 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type Operation string
|
||||
|
||||
const (
|
||||
Create Operation = "create"
|
||||
Update Operation = "update"
|
||||
Delete Operation = "delete"
|
||||
)
|
||||
|
||||
type ProxyStatus string
|
||||
|
||||
const (
|
||||
StatusPending ProxyStatus = "pending"
|
||||
StatusActive ProxyStatus = "active"
|
||||
StatusTunnelNotCreated ProxyStatus = "tunnel_not_created"
|
||||
StatusCertificatePending ProxyStatus = "certificate_pending"
|
||||
StatusCertificateFailed ProxyStatus = "certificate_failed"
|
||||
StatusError ProxyStatus = "error"
|
||||
|
||||
TargetTypePeer = "peer"
|
||||
TargetTypeHost = "host"
|
||||
TargetTypeDomain = "domain"
|
||||
TargetTypeSubnet = "subnet"
|
||||
)
|
||||
|
||||
type Target struct {
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
}
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type PINAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Pin string `json:"pin"`
|
||||
}
|
||||
|
||||
type BearerAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
func (a *AuthConfig) HashSecrets() error {
|
||||
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
||||
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
a.PasswordAuth.Password = hashedPassword
|
||||
}
|
||||
|
||||
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
||||
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash pin: %w", err)
|
||||
}
|
||||
a.PinAuth.Pin = hashedPin
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthConfig) ClearSecrets() {
|
||||
if a.PasswordAuth != nil {
|
||||
a.PasswordAuth.Password = ""
|
||||
}
|
||||
if a.PinAuth != nil {
|
||||
a.PinAuth.Pin = ""
|
||||
}
|
||||
}
|
||||
|
||||
type OIDCValidationConfig struct {
|
||||
Issuer string
|
||||
Audiences []string
|
||||
KeysLocation string
|
||||
MaxTokenAgeSeconds int64
|
||||
}
|
||||
|
||||
type ServiceMeta struct {
|
||||
CreatedAt time.Time
|
||||
CertificateIssuedAt time.Time
|
||||
Status string
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
ProxyCluster string `gorm:"index"`
|
||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||
Enabled bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta ServiceMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
}
|
||||
|
||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||
for _, target := range targets {
|
||||
target.AccountID = accountID
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
ProxyCluster: proxyCluster,
|
||||
Targets: targets,
|
||||
Enabled: enabled,
|
||||
}
|
||||
s.InitNewRecord()
|
||||
return s
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
// Service record. This overwrites any existing ID and Meta fields and should
|
||||
// only be called during initial creation, not for updates.
|
||||
func (s *Service) InitNewRecord() {
|
||||
s.ID = xid.New().String()
|
||||
s.Meta = ServiceMeta{
|
||||
CreatedAt: time.Now(),
|
||||
Status: string(StatusPending),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ToAPIResponse() *api.Service {
|
||||
s.Auth.ClearSecrets()
|
||||
|
||||
authConfig := api.ServiceAuthConfig{}
|
||||
|
||||
if s.Auth.PasswordAuth != nil {
|
||||
authConfig.PasswordAuth = &api.PasswordAuthConfig{
|
||||
Enabled: s.Auth.PasswordAuth.Enabled,
|
||||
Password: s.Auth.PasswordAuth.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if s.Auth.PinAuth != nil {
|
||||
authConfig.PinAuth = &api.PINAuthConfig{
|
||||
Enabled: s.Auth.PinAuth.Enabled,
|
||||
Pin: s.Auth.PinAuth.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if s.Auth.BearerAuth != nil {
|
||||
authConfig.BearerAuth = &api.BearerAuthConfig{
|
||||
Enabled: s.Auth.BearerAuth.Enabled,
|
||||
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert internal targets to API targets
|
||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
apiTargets = append(apiTargets, api.ServiceTarget{
|
||||
Path: target.Path,
|
||||
Host: &target.Host,
|
||||
Port: target.Port,
|
||||
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
||||
TargetId: target.TargetId,
|
||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||
Enabled: target.Enabled,
|
||||
})
|
||||
}
|
||||
|
||||
meta := api.ServiceMeta{
|
||||
CreatedAt: s.Meta.CreatedAt,
|
||||
Status: api.ServiceMetaStatus(s.Meta.Status),
|
||||
}
|
||||
|
||||
if !s.Meta.CertificateIssuedAt.IsZero() {
|
||||
meta.CertificateIssuedAt = &s.Meta.CertificateIssuedAt
|
||||
}
|
||||
|
||||
resp := &api.Service{
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
Meta: meta,
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
resp.ProxyCluster = &s.ProxyCluster
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: Make path prefix stripping configurable per-target.
|
||||
// Currently the matching prefix is baked into the target URL path,
|
||||
// so the proxy strips-then-re-adds it (effectively a no-op).
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Path: "/", // TODO: support service path
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
pathMappings = append(pathMappings, &proto.PathMapping{
|
||||
Path: path,
|
||||
Target: targetURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
auth := &proto.Authentication{
|
||||
SessionKey: s.SessionPublicKey,
|
||||
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
||||
}
|
||||
|
||||
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
|
||||
auth.Password = true
|
||||
}
|
||||
|
||||
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
|
||||
auth.Pin = true
|
||||
}
|
||||
|
||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||
auth.Oidc = true
|
||||
}
|
||||
|
||||
return &proto.ProxyMapping{
|
||||
Type: operationToProtoType(operation),
|
||||
Id: s.ID,
|
||||
Domain: s.Domain,
|
||||
Path: pathMappings,
|
||||
AuthToken: authToken,
|
||||
Auth: auth,
|
||||
AccountId: s.AccountID,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
}
|
||||
}
|
||||
|
||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
switch op {
|
||||
case Create:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
case Update:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||
case Delete:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
default:
|
||||
log.Fatalf("unknown operation type: %v", op)
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
}
|
||||
}
|
||||
|
||||
// isDefaultPort reports whether port is the standard default for the given scheme
|
||||
// (443 for https, 80 for http).
|
||||
func isDefaultPort(scheme string, port int) bool {
|
||||
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
||||
}
|
||||
|
||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
|
||||
s.Name = req.Name
|
||||
s.Domain = req.Domain
|
||||
s.AccountID = accountID
|
||||
|
||||
targets := make([]*Target, 0, len(req.Targets))
|
||||
for _, apiTarget := range req.Targets {
|
||||
target := &Target{
|
||||
AccountID: accountID,
|
||||
Path: apiTarget.Path,
|
||||
Port: apiTarget.Port,
|
||||
Protocol: string(apiTarget.Protocol),
|
||||
TargetId: apiTarget.TargetId,
|
||||
TargetType: string(apiTarget.TargetType),
|
||||
Enabled: apiTarget.Enabled,
|
||||
}
|
||||
if apiTarget.Host != nil {
|
||||
target.Host = *apiTarget.Host
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
s.Targets = targets
|
||||
|
||||
s.Enabled = req.Enabled
|
||||
|
||||
if req.PassHostHeader != nil {
|
||||
s.PassHostHeader = *req.PassHostHeader
|
||||
}
|
||||
|
||||
if req.RewriteRedirects != nil {
|
||||
s.RewriteRedirects = *req.RewriteRedirects
|
||||
}
|
||||
|
||||
if req.Auth.PasswordAuth != nil {
|
||||
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: req.Auth.PasswordAuth.Enabled,
|
||||
Password: req.Auth.PasswordAuth.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Auth.PinAuth != nil {
|
||||
s.Auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: req.Auth.PinAuth.Enabled,
|
||||
Pin: req.Auth.PinAuth.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Auth.BearerAuth != nil {
|
||||
bearerAuth := &BearerAuthConfig{
|
||||
Enabled: req.Auth.BearerAuth.Enabled,
|
||||
}
|
||||
if req.Auth.BearerAuth.DistributionGroups != nil {
|
||||
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
||||
}
|
||||
s.Auth.BearerAuth = bearerAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Validate() error {
|
||||
if s.Name == "" {
|
||||
return errors.New("service name is required")
|
||||
}
|
||||
if len(s.Name) > 255 {
|
||||
return errors.New("service name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
}
|
||||
|
||||
if len(s.Targets) == 0 {
|
||||
return errors.New("at least one target is required")
|
||||
}
|
||||
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// host field will be ignored
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return fmt.Errorf("target %d has empty target_id", i)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster}
|
||||
}
|
||||
|
||||
func (s *Service) Copy() *Service {
|
||||
targets := make([]*Target, len(s.Targets))
|
||||
for i, target := range s.Targets {
|
||||
targetCopy := *target
|
||||
targets[i] = &targetCopy
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
ProxyCluster: s.ProxyCluster,
|
||||
Targets: targets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Auth: s.Auth,
|
||||
Meta: s.Meta,
|
||||
SessionPrivateKey: s.SessionPrivateKey,
|
||||
SessionPublicKey: s.SessionPublicKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.SessionPrivateKey != "" {
|
||||
var err error
|
||||
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.SessionPrivateKey != "" {
|
||||
var err error
|
||||
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,405 +0,0 @@
|
||||
package reverseproxy
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func validProxy() *Service {
|
||||
return &Service{
|
||||
Name: "test",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_Valid(t *testing.T) {
|
||||
require.NoError(t, validProxy().Validate())
|
||||
}
|
||||
|
||||
func TestValidate_EmptyName(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Name = ""
|
||||
assert.ErrorContains(t, rp.Validate(), "name is required")
|
||||
}
|
||||
|
||||
func TestValidate_EmptyDomain(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Domain = ""
|
||||
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
||||
}
|
||||
|
||||
func TestValidate_NoTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = nil
|
||||
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
||||
}
|
||||
|
||||
func TestValidate_EmptyTargetId(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].TargetId = ""
|
||||
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
||||
}
|
||||
|
||||
func TestValidate_InvalidTargetType(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].TargetType = "invalid"
|
||||
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
||||
}
|
||||
|
||||
func TestValidate_ResourceTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = append(rp.Targets, &Target{
|
||||
TargetId: "resource-1",
|
||||
TargetType: TargetTypeHost,
|
||||
Host: "example.org",
|
||||
Port: 443,
|
||||
Protocol: "https",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = append(rp.Targets, &Target{
|
||||
TargetId: "",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.2",
|
||||
Port: 80,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
})
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "target 1")
|
||||
assert.Contains(t, err.Error(), "empty target_id")
|
||||
}
|
||||
|
||||
func TestIsDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
scheme string
|
||||
port int
|
||||
want bool
|
||||
}{
|
||||
{"http", 80, true},
|
||||
{"https", 443, true},
|
||||
{"http", 443, false},
|
||||
{"https", 80, false},
|
||||
{"http", 8080, false},
|
||||
{"https", 8443, false},
|
||||
{"http", 0, false},
|
||||
{"https", 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
oidcConfig := OIDCValidationConfig{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
host string
|
||||
port int
|
||||
wantTarget string
|
||||
}{
|
||||
{
|
||||
name: "http with default port 80 omits port",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 80,
|
||||
wantTarget: "http://10.0.0.1/",
|
||||
},
|
||||
{
|
||||
name: "https with default port 443 omits port",
|
||||
protocol: "https",
|
||||
host: "10.0.0.1",
|
||||
port: 443,
|
||||
wantTarget: "https://10.0.0.1/",
|
||||
},
|
||||
{
|
||||
name: "port 0 omits port",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 0,
|
||||
wantTarget: "http://10.0.0.1/",
|
||||
},
|
||||
{
|
||||
name: "non-default port is included",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 8080,
|
||||
wantTarget: "http://10.0.0.1:8080/",
|
||||
},
|
||||
{
|
||||
name: "https with non-default port is included",
|
||||
protocol: "https",
|
||||
host: "10.0.0.1",
|
||||
port: 8443,
|
||||
wantTarget: "https://10.0.0.1:8443/",
|
||||
},
|
||||
{
|
||||
name: "http port 443 is included",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 443,
|
||||
wantTarget: "http://10.0.0.1:443/",
|
||||
},
|
||||
{
|
||||
name: "https port 80 is included",
|
||||
protocol: "https",
|
||||
host: "10.0.0.1",
|
||||
port: 80,
|
||||
wantTarget: "https://10.0.0.1:80/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "test-id",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: tt.host,
|
||||
Port: tt.port,
|
||||
Protocol: tt.protocol,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
||||
require.Len(t, pm.Path, 1, "should have one path mapping")
|
||||
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "test-id",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||
}
|
||||
|
||||
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||
rp := validProxy()
|
||||
tests := []struct {
|
||||
op Operation
|
||||
want proto.ProxyMappingUpdateType
|
||||
}{
|
||||
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
||||
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
||||
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.op), func(t *testing.T) {
|
||||
pm := rp.ToProtoMapping(tt.op, "", OIDCValidationConfig{})
|
||||
assert.Equal(t, tt.want, pm.Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfig_HashSecrets(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *AuthConfig
|
||||
wantErr bool
|
||||
validate func(*testing.T, *AuthConfig)
|
||||
}{
|
||||
{
|
||||
name: "hash password successfully",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "testPassword123",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
||||
}
|
||||
// Verify the hash can be verified
|
||||
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
||||
t.Errorf("Hash verification failed: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "hash PIN successfully",
|
||||
config: &AuthConfig{
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "123456",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
||||
}
|
||||
// Verify the hash can be verified
|
||||
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
||||
t.Errorf("Hash verification failed: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "hash both password and PIN",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "password",
|
||||
},
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "9999",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||
t.Errorf("Password not hashed with argon2id")
|
||||
}
|
||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||
t.Errorf("PIN not hashed with argon2id")
|
||||
}
|
||||
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
||||
t.Errorf("Password hash verification failed: %v", err)
|
||||
}
|
||||
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
||||
t.Errorf("PIN hash verification failed: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip disabled password auth",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: false,
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if config.PasswordAuth.Password != "password" {
|
||||
t.Errorf("Disabled password auth should not be hashed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip empty password",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if config.PasswordAuth.Password != "" {
|
||||
t.Errorf("Empty password should remain empty")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip nil password auth",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: nil,
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "1234",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if config.PasswordAuth != nil {
|
||||
t.Errorf("PasswordAuth should remain nil")
|
||||
}
|
||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||
t.Errorf("PIN should still be hashed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.HashSecrets()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, tt.config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
||||
config := &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "correctPassword",
|
||||
},
|
||||
}
|
||||
|
||||
if err := config.HashSecrets(); err != nil {
|
||||
t.Fatalf("HashSecrets() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify with wrong password should fail
|
||||
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
||||
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||
config := &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashedPassword",
|
||||
},
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashedPin",
|
||||
},
|
||||
}
|
||||
|
||||
config.ClearSecrets()
|
||||
|
||||
if config.PasswordAuth.Password != "" {
|
||||
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
||||
}
|
||||
if config.PinAuth.Pin != "" {
|
||||
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package reverseproxy
|
||||
package service
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package reverseproxy -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
//go:generate go run github.com/golang/mock/mockgen -package service -destination=interface_mock.go -source=./interface.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -14,11 +14,15 @@ type Manager interface {
|
||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||
DeleteAllServices(ctx context.Context, accountID, userID string) error
|
||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||
ReloadService(ctx context.Context, accountID, serviceID string) error
|
||||
GetGlobalServices(ctx context.Context) ([]*Service, error)
|
||||
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
|
||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./interface.go
|
||||
|
||||
// Package reverseproxy is a generated GoMock package.
|
||||
package reverseproxy
|
||||
// Package service is a generated GoMock package.
|
||||
package service
|
||||
|
||||
import (
|
||||
context "context"
|
||||
@@ -49,6 +49,21 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer mocks base method.
|
||||
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
|
||||
ret0, _ := ret[0].(*ExposeServiceResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -195,6 +210,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer mocks base method.
|
||||
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt mocks base method.
|
||||
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -210,7 +239,7 @@ func (mr *MockManagerMockRecorder) SetCertificateIssuedAt(ctx, accountID, servic
|
||||
}
|
||||
|
||||
// SetStatus mocks base method.
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error {
|
||||
func (m *MockManager) SetStatus(ctx context.Context, accountID, serviceID string, status Status) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SetStatus", ctx, accountID, serviceID, status)
|
||||
ret0, _ := ret[0].(error)
|
||||
@@ -223,6 +252,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
|
||||
}
|
||||
|
||||
// StartExposeReaper mocks base method.
|
||||
func (m *MockManager) StartExposeReaper(ctx context.Context) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "StartExposeReaper", ctx)
|
||||
}
|
||||
|
||||
// StartExposeReaper indicates an expected call of StartExposeReaper.
|
||||
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
|
||||
}
|
||||
|
||||
// StopServiceFromPeer mocks base method.
|
||||
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6,10 +6,11 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
@@ -20,11 +21,11 @@ import (
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager reverseproxy.Manager
|
||||
manager rpservice.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all service HTTP endpoints.
|
||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
@@ -63,8 +64,11 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
service := new(reverseproxy.Service)
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
service := new(rpservice.Service)
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
@@ -109,9 +113,12 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
service := new(reverseproxy.Service)
|
||||
service := new(rpservice.Service)
|
||||
service.ID = serviceID
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
if err = service.FromAPIRequest(&req, userAuth.AccountId); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err := service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
@@ -0,0 +1,65 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
exposeTTL = 90 * time.Second
|
||||
exposeReapInterval = 30 * time.Second
|
||||
maxExposesPerPeer = 10
|
||||
exposeReapBatch = 100
|
||||
)
|
||||
|
||||
type exposeReaper struct {
|
||||
manager *Manager
|
||||
}
|
||||
|
||||
// StartExposeReaper starts a background goroutine that reaps expired ephemeral services from the DB.
|
||||
func (r *exposeReaper) StartExposeReaper(ctx context.Context) {
|
||||
go func() {
|
||||
// start with a random delay
|
||||
rn := rand.IntN(10)
|
||||
time.Sleep(time.Duration(rn) * time.Second)
|
||||
|
||||
ticker := time.NewTicker(exposeReapInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
r.reapExpiredExposes(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (r *exposeReaper) reapExpiredExposes(ctx context.Context) {
|
||||
expired, err := r.manager.store.GetExpiredEphemeralServices(ctx, exposeTTL, exposeReapBatch)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get expired ephemeral services: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, svc := range expired {
|
||||
log.Infof("reaping expired expose session for peer %s, domain %s", svc.SourcePeer, svc.Domain)
|
||||
|
||||
err := r.manager.deleteExpiredPeerService(ctx, svc.AccountID, svc.SourcePeer, svc.ID)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if s, ok := status.FromError(err); ok && s.ErrorType == status.NotFound {
|
||||
log.Debugf("service %s was already deleted by another instance", svc.Domain)
|
||||
} else {
|
||||
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", svc.Domain, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,208 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestReapExpiredExposes(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually expire the service by backdating meta_last_renewed_at
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Create a non-expired service
|
||||
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8081,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
// Expired service should be deleted
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
require.Error(t, err, "expired service should be deleted")
|
||||
|
||||
// Non-expired service should remain
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
|
||||
require.NoError(t, err, "active service should remain")
|
||||
}
|
||||
|
||||
func TestReapAlreadyDeletedService(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Delete the service before reaping
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reaping should handle the already-deleted service gracefully
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
}
|
||||
|
||||
func TestConcurrentReapAndRenew(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range 5 {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Expire all services
|
||||
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
|
||||
require.NoError(t, err)
|
||||
for _, svc := range services {
|
||||
if svc.Source == rpservice.SourceEphemeral {
|
||||
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count, "all expired services should be reaped")
|
||||
}
|
||||
|
||||
func TestRenewEphemeralService(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("renew succeeds for active service", func(t *testing.T) {
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8082,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no active expose session")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCountAndExistsEphemeralServices(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8083,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "service should exist")
|
||||
|
||||
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "non-existent service should not exist")
|
||||
}
|
||||
|
||||
func TestMaxExposesPerPeerEnforced(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range maxExposesPerPeer {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8090 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err, "expose %d should succeed", i)
|
||||
}
|
||||
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9999,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
|
||||
}
|
||||
|
||||
func TestReapSkipsRenewedService(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8086,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expire the service
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Renew it before the reaper runs
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
require.NoError(t, err, "renewed service should survive reaping")
|
||||
}
|
||||
|
||||
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||
t.Helper()
|
||||
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
expired := time.Now().Add(-2 * exposeTTL)
|
||||
svc.Meta.LastRenewedAt = &expired
|
||||
err = s.UpdateService(context.Background(), svc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -0,0 +1,928 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const unknownHostPlaceholder = "unknown"
|
||||
|
||||
// ClusterDeriver derives the proxy cluster from a domain.
|
||||
type ClusterDeriver interface {
|
||||
DeriveClusterFromDomain(ctx context.Context, accountID, domain string) (string, error)
|
||||
GetClusterDomains() []string
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
store store.Store
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
proxyController proxy.Controller
|
||||
clusterDeriver ClusterDeriver
|
||||
exposeReaper *exposeReaper
|
||||
}
|
||||
|
||||
// NewManager creates a new service manager.
|
||||
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, proxyController proxy.Controller, clusterDeriver ClusterDeriver) *Manager {
|
||||
mgr := &Manager{
|
||||
store: store,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
proxyController: proxyController,
|
||||
clusterDeriver: clusterDeriver,
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
return mgr
|
||||
}
|
||||
|
||||
// StartExposeReaper starts the background goroutine that reaps expired ephemeral services.
|
||||
func (m *Manager) StartExposeReaper(ctx context.Context) {
|
||||
m.exposeReaper.StartExposeReaper(ctx)
|
||||
}
|
||||
|
||||
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *service.Service) error {
|
||||
for _, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case service.TargetTypePeer:
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get peer by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = peer.IP.String()
|
||||
case service.TargetTypeHost:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Prefix.Addr().String()
|
||||
case service.TargetTypeDomain:
|
||||
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, target.TargetId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get resource by id %s for service %s: %v", target.TargetId, s.ID, err)
|
||||
target.Host = unknownHostPlaceholder
|
||||
continue
|
||||
}
|
||||
target.Host = resource.Domain
|
||||
case service.TargetTypeSubnet:
|
||||
// For subnets we do not do any lookups on the resource
|
||||
default:
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetService(ctx context.Context, accountID, userID, serviceID string) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Create)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := m.persistNewService(ctx, accountID, s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, s.ID, accountID, activity.ServiceCreated, s.EventMeta())
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID string, service *service.Service) error {
|
||||
if m.clusterDeriver != nil {
|
||||
proxyCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s, updates will broadcast to all proxy servers", service.Domain)
|
||||
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
|
||||
}
|
||||
service.ProxyCluster = proxyCluster
|
||||
}
|
||||
|
||||
service.AccountID = accountID
|
||||
service.InitNewRecord()
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
keyPair, err := sessionkey.GenerateKeyPair()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate session keys: %w", err)
|
||||
}
|
||||
service.SessionPrivateKey = keyPair.PrivateKey
|
||||
service.SessionPublicKey = keyPair.PublicKey
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// persistNewEphemeralService creates an ephemeral service inside a single transaction
|
||||
// that also enforces the duplicate and per-peer limit checks atomically.
|
||||
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
|
||||
// for the same peer, preventing the per-peer limit from being bypassed.
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
// Lock the peer row to serialize concurrent creates for the same peer.
|
||||
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
|
||||
// table returns no rows and acquires no locks, allowing concurrent inserts
|
||||
// to bypass the per-peer limit.
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
|
||||
return fmt.Errorf("lock peer row: %w", err)
|
||||
}
|
||||
|
||||
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing expose: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
|
||||
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count peer exposes: %w", err)
|
||||
}
|
||||
if count >= int64(maxExposesPerPeer) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if existingService != nil && existingService.ID != excludeServiceID {
|
||||
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, service *service.Service) (*service.Service, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Update)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
if err := service.Auth.HashSecrets(); err != nil {
|
||||
return nil, fmt.Errorf("hash secrets: %w", err)
|
||||
}
|
||||
|
||||
updateInfo, err := m.persistServiceUpdate(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceUpdated, service.EventMeta())
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
type serviceUpdateInfo struct {
|
||||
oldCluster string
|
||||
domainChanged bool
|
||||
serviceEnabledChanged bool
|
||||
}
|
||||
|
||||
func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, service *service.Service) (*serviceUpdateInfo, error) {
|
||||
var updateInfo serviceUpdateInfo
|
||||
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
if updateInfo.domainChanged {
|
||||
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
service.ProxyCluster = existingService.ProxyCluster
|
||||
}
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
service.Auth.PasswordAuth.Password == "" {
|
||||
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
|
||||
}
|
||||
|
||||
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
|
||||
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
|
||||
service.Auth.PinAuth.Pin == "" {
|
||||
service.Auth.PinAuth = existingService.Auth.PinAuth
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
|
||||
service.Meta = existingService.Meta
|
||||
service.SessionPrivateKey = existingService.SessionPrivateKey
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
case !s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), s.ProxyCluster)
|
||||
case s.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
|
||||
default:
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", oidcCfg), s.ProxyCluster)
|
||||
}
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||
for _, target := range targets {
|
||||
switch target.TargetType {
|
||||
case service.TargetTypePeer:
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
|
||||
}
|
||||
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
|
||||
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
|
||||
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
|
||||
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
|
||||
}
|
||||
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var s *service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
s, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete targets: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, s.EventMeta())
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*service.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetAccountServices(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, svc := range services {
|
||||
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
for _, svc := range services {
|
||||
m.accountManager.StoreEvent(ctx, userID, svc.ID, accountID, activity.ServiceDeleted, svc.EventMeta())
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt sets the certificate issued timestamp to the current time.
|
||||
// Call this when receiving a gRPC notification that the certificate was issued.
|
||||
func (m *Manager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
service.Meta.CertificateIssuedAt = &now
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service certificate timestamp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// SetStatus updates the status of the service (e.g., "active", "tunnel_not_created", etc.)
|
||||
func (m *Manager) SetStatus(ctx context.Context, accountID, serviceID string, status service.Status) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
service, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
service.Meta.Status = string(status)
|
||||
|
||||
if err = transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to update service status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string) error {
|
||||
s, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) ReloadAllServicesForAccount(ctx context.Context, accountID string) error {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, s := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", s.ID, err)
|
||||
}
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetGlobalServices(ctx context.Context) ([]*service.Service, error) {
|
||||
services, err := m.store.GetServices(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, service.AccountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceByID(ctx context.Context, accountID, serviceID string) (*service.Service, error) {
|
||||
service, err := m.store.GetServiceByID(ctx, store.LockingStrengthNone, accountID, serviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get service: %w", err)
|
||||
}
|
||||
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetAccountServices(ctx context.Context, accountID string) ([]*service.Service, error) {
|
||||
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return services, nil
|
||||
}
|
||||
|
||||
func (m *Manager) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) {
|
||||
target, err := m.store.GetServiceTargetByTargetID(ctx, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get service target by resource ID: %w", err)
|
||||
}
|
||||
|
||||
if target == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return target.ServiceID, nil
|
||||
}
|
||||
|
||||
// validateExposePermission checks whether the peer is allowed to use the expose feature.
|
||||
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
|
||||
func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerID string) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return status.Errorf(status.Internal, "get account settings: %v", err)
|
||||
}
|
||||
|
||||
if !settings.PeerExposeEnabled {
|
||||
return status.Errorf(status.PermissionDenied, "peer expose is not enabled for this account")
|
||||
}
|
||||
|
||||
if len(settings.PeerExposeGroups) == 0 {
|
||||
return status.Errorf(status.PermissionDenied, "no group is set for peer expose")
|
||||
}
|
||||
|
||||
peerGroupIDs, err := m.store.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer group IDs: %v", err)
|
||||
return status.Errorf(status.Internal, "get peer groups: %v", err)
|
||||
}
|
||||
|
||||
for _, pg := range peerGroupIDs {
|
||||
if slices.Contains(settings.PeerExposeGroups, pg) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer creates a service initiated by a peer expose request.
|
||||
// It validates the request, checks expose permissions, enforces the per-peer limit,
|
||||
// creates the service, and tracks it for TTL-based reaping.
|
||||
func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *service.ExposeServiceRequest) (*service.ExposeServiceResponse, error) {
|
||||
if err := req.Validate(); err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
|
||||
}
|
||||
|
||||
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serviceName, err := service.GenerateExposeName(req.NamePrefix)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
|
||||
}
|
||||
|
||||
svc := req.ToService(accountID, peerID, serviceName)
|
||||
svc.Source = service.SourceEphemeral
|
||||
|
||||
if svc.Domain == "" {
|
||||
domain, err := m.buildRandomDomain(svc.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
|
||||
}
|
||||
svc.Domain = domain
|
||||
}
|
||||
|
||||
if svc.Auth.BearerAuth != nil && svc.Auth.BearerAuth.Enabled {
|
||||
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, svc.Auth.BearerAuth.DistributionGroups)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group ids for service %s: %w", svc.Name, err)
|
||||
}
|
||||
svc.Auth.BearerAuth.DistributionGroups = groupIDs
|
||||
}
|
||||
|
||||
if err := m.initializeServiceForCreate(ctx, accountID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
svc.SourcePeer = peerID
|
||||
|
||||
now := time.Now()
|
||||
svc.Meta.LastRenewedAt = &now
|
||||
|
||||
if err := m.persistNewEphemeralService(ctx, accountID, peerID, svc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, svc.ID, accountID, activity.PeerServiceExposed, meta)
|
||||
|
||||
if err := m.replaceHostByLookup(ctx, accountID, svc); err != nil {
|
||||
return nil, fmt.Errorf("replace host by lookup for service %s: %w", svc.ID, err)
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return &service.ExposeServiceResponse{
|
||||
ServiceName: svc.Name,
|
||||
ServiceURL: "https://" + svc.Domain,
|
||||
Domain: svc.Domain,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
|
||||
if len(groupNames) == 0 {
|
||||
return []string{}, fmt.Errorf("no group names provided")
|
||||
}
|
||||
groupIDs := make([]string, 0, len(groupNames))
|
||||
for _, groupName := range groupNames {
|
||||
g, err := m.accountManager.GetGroupByName(ctx, groupName, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group by name %s: %w", groupName, err)
|
||||
}
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||
if m.clusterDeriver == nil {
|
||||
return "", fmt.Errorf("unable to get random domain")
|
||||
}
|
||||
clusterDomains := m.clusterDeriver.GetClusterDomains()
|
||||
if len(clusterDomains) == 0 {
|
||||
return "", fmt.Errorf("no cluster domains found for service %s", name)
|
||||
}
|
||||
index := rand.IntN(len(clusterDomains))
|
||||
domain := name + "." + clusterDomains[index]
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
|
||||
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
|
||||
}
|
||||
|
||||
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
|
||||
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
|
||||
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
|
||||
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
|
||||
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
activityCode := activity.PeerServiceUnexposed
|
||||
if expired {
|
||||
activityCode = activity.PeerServiceExposeExpired
|
||||
}
|
||||
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
|
||||
}
|
||||
|
||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
|
||||
}
|
||||
|
||||
if svc.SourcePeer != peerID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
|
||||
var svc *service.Service
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete API-created service via peer expose")
|
||||
}
|
||||
|
||||
if svc.SourcePeer != peerID {
|
||||
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
|
||||
peer = nil
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activityCode, meta)
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteExpiredPeerService deletes an ephemeral service by ID after re-checking
|
||||
// that it is still expired under a row lock. This prevents deleting a service
|
||||
// that was renewed between the batch query and this delete, and ensures only one
|
||||
// management instance processes the deletion
|
||||
func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerID, serviceID string) error {
|
||||
var svc *service.Service
|
||||
deleted := false
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
svc, err = transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral || svc.SourcePeer != peerID {
|
||||
return status.Errorf(status.PermissionDenied, "service does not match expected ephemeral owner")
|
||||
}
|
||||
|
||||
if svc.Meta.LastRenewedAt != nil && time.Since(*svc.Meta.LastRenewedAt) <= exposeTTL {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
|
||||
return fmt.Errorf("delete service: %w", err)
|
||||
}
|
||||
deleted = true
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !deleted {
|
||||
return nil
|
||||
}
|
||||
|
||||
peer, err := m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to get peer %s for event metadata: %v", peerID, err)
|
||||
peer = nil
|
||||
}
|
||||
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addPeerInfoToEventMeta(meta map[string]any, peer *nbpeer.Peer) map[string]any {
|
||||
if peer == nil {
|
||||
return meta
|
||||
}
|
||||
meta["peer_name"] = peer.Name
|
||||
if peer.IP != nil {
|
||||
meta["peer_ip"] = peer.IP.String()
|
||||
}
|
||||
return meta
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
817
management/internals/modules/reverseproxy/service/service.go
Normal file
817
management/internals/modules/reverseproxy/service/service.go
Normal file
@@ -0,0 +1,817 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type Operation string
|
||||
|
||||
const (
|
||||
Create Operation = "create"
|
||||
Update Operation = "update"
|
||||
Delete Operation = "delete"
|
||||
)
|
||||
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusPending Status = "pending"
|
||||
StatusActive Status = "active"
|
||||
StatusTunnelNotCreated Status = "tunnel_not_created"
|
||||
StatusCertificatePending Status = "certificate_pending"
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer = "peer"
|
||||
TargetTypeHost = "host"
|
||||
TargetTypeDomain = "domain"
|
||||
TargetTypeSubnet = "subnet"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
)
|
||||
|
||||
type TargetOptions struct {
|
||||
SkipTLSVerify bool `json:"skip_tls_verify"`
|
||||
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
Options TargetOptions `gorm:"embedded" json:"options"`
|
||||
}
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type PINAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Pin string `json:"pin"`
|
||||
}
|
||||
|
||||
type BearerAuthConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
DistributionGroups []string `json:"distribution_groups,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
type AuthConfig struct {
|
||||
PasswordAuth *PasswordAuthConfig `json:"password_auth,omitempty" gorm:"serializer:json"`
|
||||
PinAuth *PINAuthConfig `json:"pin_auth,omitempty" gorm:"serializer:json"`
|
||||
BearerAuth *BearerAuthConfig `json:"bearer_auth,omitempty" gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
func (a *AuthConfig) HashSecrets() error {
|
||||
if a.PasswordAuth != nil && a.PasswordAuth.Enabled && a.PasswordAuth.Password != "" {
|
||||
hashedPassword, err := argon2id.Hash(a.PasswordAuth.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
a.PasswordAuth.Password = hashedPassword
|
||||
}
|
||||
|
||||
if a.PinAuth != nil && a.PinAuth.Enabled && a.PinAuth.Pin != "" {
|
||||
hashedPin, err := argon2id.Hash(a.PinAuth.Pin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash pin: %w", err)
|
||||
}
|
||||
a.PinAuth.Pin = hashedPin
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *AuthConfig) ClearSecrets() {
|
||||
if a.PasswordAuth != nil {
|
||||
a.PasswordAuth.Password = ""
|
||||
}
|
||||
if a.PinAuth != nil {
|
||||
a.PinAuth.Pin = ""
|
||||
}
|
||||
}
|
||||
|
||||
type Meta struct {
|
||||
CreatedAt time.Time
|
||||
CertificateIssuedAt *time.Time
|
||||
Status string
|
||||
LastRenewedAt *time.Time
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
AccountID string `gorm:"index"`
|
||||
Name string
|
||||
Domain string `gorm:"index"`
|
||||
ProxyCluster string `gorm:"index"`
|
||||
Targets []*Target `gorm:"foreignKey:ServiceID;constraint:OnDelete:CASCADE"`
|
||||
Enabled bool
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
Auth AuthConfig `gorm:"serializer:json"`
|
||||
Meta Meta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
SessionPrivateKey string `gorm:"column:session_private_key"`
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
}
|
||||
|
||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||
for _, target := range targets {
|
||||
target.AccountID = accountID
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
ProxyCluster: proxyCluster,
|
||||
Targets: targets,
|
||||
Enabled: enabled,
|
||||
}
|
||||
s.InitNewRecord()
|
||||
return s
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
// Service record. This overwrites any existing ID and Meta fields and should
|
||||
// only be called during initial creation, not for updates.
|
||||
func (s *Service) InitNewRecord() {
|
||||
s.ID = xid.New().String()
|
||||
s.Meta = Meta{
|
||||
CreatedAt: time.Now(),
|
||||
Status: string(StatusPending),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ToAPIResponse() *api.Service {
|
||||
s.Auth.ClearSecrets()
|
||||
|
||||
authConfig := api.ServiceAuthConfig{}
|
||||
|
||||
if s.Auth.PasswordAuth != nil {
|
||||
authConfig.PasswordAuth = &api.PasswordAuthConfig{
|
||||
Enabled: s.Auth.PasswordAuth.Enabled,
|
||||
Password: s.Auth.PasswordAuth.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if s.Auth.PinAuth != nil {
|
||||
authConfig.PinAuth = &api.PINAuthConfig{
|
||||
Enabled: s.Auth.PinAuth.Enabled,
|
||||
Pin: s.Auth.PinAuth.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if s.Auth.BearerAuth != nil {
|
||||
authConfig.BearerAuth = &api.BearerAuthConfig{
|
||||
Enabled: s.Auth.BearerAuth.Enabled,
|
||||
DistributionGroups: &s.Auth.BearerAuth.DistributionGroups,
|
||||
}
|
||||
}
|
||||
|
||||
// Convert internal targets to API targets
|
||||
apiTargets := make([]api.ServiceTarget, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
st := api.ServiceTarget{
|
||||
Path: target.Path,
|
||||
Host: &target.Host,
|
||||
Port: target.Port,
|
||||
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
||||
TargetId: target.TargetId,
|
||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||
Enabled: target.Enabled,
|
||||
}
|
||||
st.Options = targetOptionsToAPI(target.Options)
|
||||
apiTargets = append(apiTargets, st)
|
||||
}
|
||||
|
||||
meta := api.ServiceMeta{
|
||||
CreatedAt: s.Meta.CreatedAt,
|
||||
Status: api.ServiceMetaStatus(s.Meta.Status),
|
||||
}
|
||||
|
||||
if s.Meta.CertificateIssuedAt != nil {
|
||||
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
|
||||
}
|
||||
|
||||
resp := &api.Service{
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
Targets: apiTargets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: &s.PassHostHeader,
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
Meta: meta,
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
resp.ProxyCluster = &s.ProxyCluster
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: Make path prefix stripping configurable per-target.
|
||||
// Currently the matching prefix is baked into the target URL path,
|
||||
// so the proxy strips-then-re-adds it (effectively a no-op).
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Path: "/", // TODO: support service path
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
pm := &proto.PathMapping{
|
||||
Path: path,
|
||||
Target: targetURL.String(),
|
||||
}
|
||||
|
||||
pm.Options = targetOptionsToProto(target.Options)
|
||||
pathMappings = append(pathMappings, pm)
|
||||
}
|
||||
|
||||
auth := &proto.Authentication{
|
||||
SessionKey: s.SessionPublicKey,
|
||||
MaxSessionAgeSeconds: int64((time.Hour * 24).Seconds()),
|
||||
}
|
||||
|
||||
if s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled {
|
||||
auth.Password = true
|
||||
}
|
||||
|
||||
if s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled {
|
||||
auth.Pin = true
|
||||
}
|
||||
|
||||
if s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled {
|
||||
auth.Oidc = true
|
||||
}
|
||||
|
||||
return &proto.ProxyMapping{
|
||||
Type: operationToProtoType(operation),
|
||||
Id: s.ID,
|
||||
Domain: s.Domain,
|
||||
Path: pathMappings,
|
||||
AuthToken: authToken,
|
||||
Auth: auth,
|
||||
AccountId: s.AccountID,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
}
|
||||
}
|
||||
|
||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
switch op {
|
||||
case Create:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
case Update:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED
|
||||
case Delete:
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED
|
||||
default:
|
||||
log.Fatalf("unknown operation type: %v", op)
|
||||
return proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED
|
||||
}
|
||||
}
|
||||
|
||||
// isDefaultPort reports whether port is the standard default for the given scheme
|
||||
// (443 for https, 80 for http).
|
||||
func isDefaultPort(scheme string, port int) bool {
|
||||
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
||||
}
|
||||
|
||||
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||
type PathRewriteMode string
|
||||
|
||||
const (
|
||||
PathRewritePreserve PathRewriteMode = "preserve"
|
||||
)
|
||||
|
||||
func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
switch mode {
|
||||
case PathRewritePreserve:
|
||||
return proto.PathRewriteMode_PATH_REWRITE_PRESERVE
|
||||
default:
|
||||
return proto.PathRewriteMode_PATH_REWRITE_DEFAULT
|
||||
}
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
if opts.SkipTLSVerify {
|
||||
apiOpts.SkipTlsVerify = &opts.SkipTLSVerify
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
s := opts.RequestTimeout.String()
|
||||
apiOpts.RequestTimeout = &s
|
||||
}
|
||||
if opts.PathRewrite != "" {
|
||||
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
|
||||
apiOpts.PathRewrite = &pr
|
||||
}
|
||||
if len(opts.CustomHeaders) > 0 {
|
||||
apiOpts.CustomHeaders = &opts.CustomHeaders
|
||||
}
|
||||
return apiOpts
|
||||
}
|
||||
|
||||
func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.PathRewrite == "" && opts.RequestTimeout == 0 && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
popts := &proto.PathTargetOptions{
|
||||
SkipTlsVerify: opts.SkipTLSVerify,
|
||||
PathRewrite: pathRewriteToProto(opts.PathRewrite),
|
||||
CustomHeaders: opts.CustomHeaders,
|
||||
}
|
||||
if opts.RequestTimeout != 0 {
|
||||
popts.RequestTimeout = durationpb.New(opts.RequestTimeout)
|
||||
}
|
||||
return popts
|
||||
}
|
||||
|
||||
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
|
||||
var opts TargetOptions
|
||||
if o.SkipTlsVerify != nil {
|
||||
opts.SkipTLSVerify = *o.SkipTlsVerify
|
||||
}
|
||||
if o.RequestTimeout != nil {
|
||||
d, err := time.ParseDuration(*o.RequestTimeout)
|
||||
if err != nil {
|
||||
return opts, fmt.Errorf("target %d: parse request_timeout %q: %w", idx, *o.RequestTimeout, err)
|
||||
}
|
||||
opts.RequestTimeout = d
|
||||
}
|
||||
if o.PathRewrite != nil {
|
||||
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
|
||||
}
|
||||
if o.CustomHeaders != nil {
|
||||
opts.CustomHeaders = *o.CustomHeaders
|
||||
}
|
||||
return opts, nil
|
||||
}
|
||||
|
||||
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) error {
|
||||
s.Name = req.Name
|
||||
s.Domain = req.Domain
|
||||
s.AccountID = accountID
|
||||
|
||||
targets := make([]*Target, 0, len(req.Targets))
|
||||
for i, apiTarget := range req.Targets {
|
||||
target := &Target{
|
||||
AccountID: accountID,
|
||||
Path: apiTarget.Path,
|
||||
Port: apiTarget.Port,
|
||||
Protocol: string(apiTarget.Protocol),
|
||||
TargetId: apiTarget.TargetId,
|
||||
TargetType: string(apiTarget.TargetType),
|
||||
Enabled: apiTarget.Enabled,
|
||||
}
|
||||
if apiTarget.Host != nil {
|
||||
target.Host = *apiTarget.Host
|
||||
}
|
||||
if apiTarget.Options != nil {
|
||||
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
target.Options = opts
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
s.Targets = targets
|
||||
|
||||
s.Enabled = req.Enabled
|
||||
|
||||
if req.PassHostHeader != nil {
|
||||
s.PassHostHeader = *req.PassHostHeader
|
||||
}
|
||||
|
||||
if req.RewriteRedirects != nil {
|
||||
s.RewriteRedirects = *req.RewriteRedirects
|
||||
}
|
||||
|
||||
if req.Auth.PasswordAuth != nil {
|
||||
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: req.Auth.PasswordAuth.Enabled,
|
||||
Password: req.Auth.PasswordAuth.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Auth.PinAuth != nil {
|
||||
s.Auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: req.Auth.PinAuth.Enabled,
|
||||
Pin: req.Auth.PinAuth.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Auth.BearerAuth != nil {
|
||||
bearerAuth := &BearerAuthConfig{
|
||||
Enabled: req.Auth.BearerAuth.Enabled,
|
||||
}
|
||||
if req.Auth.BearerAuth.DistributionGroups != nil {
|
||||
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
||||
}
|
||||
s.Auth.BearerAuth = bearerAuth
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Validate() error {
|
||||
if s.Name == "" {
|
||||
return errors.New("service name is required")
|
||||
}
|
||||
if len(s.Name) > 255 {
|
||||
return errors.New("service name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
}
|
||||
|
||||
if len(s.Targets) == 0 {
|
||||
return errors.New("at least one target is required")
|
||||
}
|
||||
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
// host field will be ignored
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return fmt.Errorf("target %d has empty host but target_type is %q", i, target.TargetType)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("target %d has invalid target_type %q", i, target.TargetType)
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return fmt.Errorf("target %d has empty target_id", i)
|
||||
}
|
||||
if err := validateTargetOptions(i, &target.Options); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
maxRequestTimeout = 5 * time.Minute
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
|
||||
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
|
||||
var httpHeaderNameRe = regexp.MustCompile(`^[!#$%&'*+\-.^_` + "`" + `|~0-9A-Za-z]+$`)
|
||||
|
||||
// hopByHopHeaders are headers that must not be set as custom headers
|
||||
// because they are connection-level and stripped by the proxy.
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"Connection": {},
|
||||
"Keep-Alive": {},
|
||||
"Proxy-Authenticate": {},
|
||||
"Proxy-Authorization": {},
|
||||
"Proxy-Connection": {},
|
||||
"Te": {},
|
||||
"Trailer": {},
|
||||
"Transfer-Encoding": {},
|
||||
"Upgrade": {},
|
||||
}
|
||||
|
||||
// reservedHeaders are set authoritatively by the proxy or control HTTP framing
|
||||
// and cannot be overridden.
|
||||
var reservedHeaders = map[string]struct{}{
|
||||
"Content-Length": {},
|
||||
"Content-Type": {},
|
||||
"Cookie": {},
|
||||
"Forwarded": {},
|
||||
"X-Forwarded-For": {},
|
||||
"X-Forwarded-Host": {},
|
||||
"X-Forwarded-Port": {},
|
||||
"X-Forwarded-Proto": {},
|
||||
"X-Real-Ip": {},
|
||||
}
|
||||
|
||||
func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||
if opts.PathRewrite != "" && opts.PathRewrite != PathRewritePreserve {
|
||||
return fmt.Errorf("target %d: unknown path_rewrite mode %q", idx, opts.PathRewrite)
|
||||
}
|
||||
|
||||
if opts.RequestTimeout != 0 {
|
||||
if opts.RequestTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: request_timeout must be positive", idx)
|
||||
}
|
||||
if opts.RequestTimeout > maxRequestTimeout {
|
||||
return fmt.Errorf("target %d: request_timeout exceeds maximum of %s", idx, maxRequestTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateCustomHeaders(idx int, headers map[string]string) error {
|
||||
if len(headers) > maxCustomHeaders {
|
||||
return fmt.Errorf("target %d: custom_headers count %d exceeds maximum of %d", idx, len(headers), maxCustomHeaders)
|
||||
}
|
||||
seen := make(map[string]string, len(headers))
|
||||
for key, value := range headers {
|
||||
if !httpHeaderNameRe.MatchString(key) {
|
||||
return fmt.Errorf("target %d: custom header key %q is not a valid HTTP header name", idx, key)
|
||||
}
|
||||
if len(key) > maxHeaderKeyLen {
|
||||
return fmt.Errorf("target %d: custom header key %q exceeds maximum length of %d", idx, key, maxHeaderKeyLen)
|
||||
}
|
||||
if len(value) > maxHeaderValueLen {
|
||||
return fmt.Errorf("target %d: custom header %q value exceeds maximum length of %d", idx, key, maxHeaderValueLen)
|
||||
}
|
||||
if containsCRLF(key) || containsCRLF(value) {
|
||||
return fmt.Errorf("target %d: custom header %q contains invalid characters", idx, key)
|
||||
}
|
||||
canonical := http.CanonicalHeaderKey(key)
|
||||
if prev, ok := seen[canonical]; ok {
|
||||
return fmt.Errorf("target %d: custom header keys %q and %q collide (both canonicalize to %q)", idx, prev, key, canonical)
|
||||
}
|
||||
seen[canonical] = key
|
||||
if _, ok := hopByHopHeaders[canonical]; ok {
|
||||
return fmt.Errorf("target %d: custom header %q is a hop-by-hop header and cannot be set", idx, key)
|
||||
}
|
||||
if _, ok := reservedHeaders[canonical]; ok {
|
||||
return fmt.Errorf("target %d: custom header %q is managed by the proxy and cannot be overridden", idx, key)
|
||||
}
|
||||
if canonical == "Host" {
|
||||
return fmt.Errorf("target %d: use pass_host_header instead of setting Host as a custom header", idx)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func containsCRLF(s string) bool {
|
||||
return strings.ContainsAny(s, "\r\n")
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
|
||||
}
|
||||
|
||||
func (s *Service) isAuthEnabled() bool {
|
||||
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
|
||||
}
|
||||
|
||||
func (s *Service) Copy() *Service {
|
||||
targets := make([]*Target, len(s.Targets))
|
||||
for i, target := range s.Targets {
|
||||
targetCopy := *target
|
||||
if len(target.Options.CustomHeaders) > 0 {
|
||||
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
|
||||
for k, v := range target.Options.CustomHeaders {
|
||||
targetCopy.Options.CustomHeaders[k] = v
|
||||
}
|
||||
}
|
||||
targets[i] = &targetCopy
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
Name: s.Name,
|
||||
Domain: s.Domain,
|
||||
ProxyCluster: s.ProxyCluster,
|
||||
Targets: targets,
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Auth: s.Auth,
|
||||
Meta: s.Meta,
|
||||
SessionPrivateKey: s.SessionPrivateKey,
|
||||
SessionPublicKey: s.SessionPublicKey,
|
||||
Source: s.Source,
|
||||
SourcePeer: s.SourcePeer,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.SessionPrivateKey != "" {
|
||||
var err error
|
||||
s.SessionPrivateKey, err = enc.Encrypt(s.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.SessionPrivateKey != "" {
|
||||
var err error
|
||||
s.SessionPrivateKey, err = enc.Decrypt(s.SessionPrivateKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
|
||||
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
|
||||
|
||||
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
|
||||
type ExposeServiceRequest struct {
|
||||
NamePrefix string
|
||||
Port int
|
||||
Protocol string
|
||||
Domain string
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
}
|
||||
|
||||
// Validate checks all fields of the expose request.
|
||||
func (r *ExposeServiceRequest) Validate() error {
|
||||
if r == nil {
|
||||
return errors.New("request cannot be nil")
|
||||
}
|
||||
|
||||
if r.Port < 1 || r.Port > 65535 {
|
||||
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
|
||||
}
|
||||
|
||||
if r.Protocol != "http" && r.Protocol != "https" {
|
||||
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
|
||||
}
|
||||
|
||||
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
|
||||
return errors.New("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
for _, g := range r.UserGroups {
|
||||
if g == "" {
|
||||
return errors.New("user group name cannot be empty")
|
||||
}
|
||||
}
|
||||
|
||||
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
|
||||
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToService builds a Service from the expose request.
|
||||
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
|
||||
service := &Service{
|
||||
AccountID: accountID,
|
||||
Name: serviceName,
|
||||
Enabled: true,
|
||||
Targets: []*Target{
|
||||
{
|
||||
AccountID: accountID,
|
||||
Port: r.Port,
|
||||
Protocol: r.Protocol,
|
||||
TargetId: peerID,
|
||||
TargetType: TargetTypePeer,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if r.Domain != "" {
|
||||
service.Domain = serviceName + "." + r.Domain
|
||||
}
|
||||
|
||||
if r.Pin != "" {
|
||||
service.Auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: r.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if r.Password != "" {
|
||||
service.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: r.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.UserGroups) > 0 {
|
||||
service.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: r.UserGroups,
|
||||
}
|
||||
}
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
// ExposeServiceResponse contains the result of a successful peer expose creation.
|
||||
type ExposeServiceResponse struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
}
|
||||
|
||||
// GenerateExposeName generates a random service name for peer-exposed services.
|
||||
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
|
||||
func GenerateExposeName(prefix string) (string, error) {
|
||||
if prefix != "" && !validNamePrefix.MatchString(prefix) {
|
||||
return "", fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", prefix)
|
||||
}
|
||||
|
||||
suffixLen := 12
|
||||
if prefix != "" {
|
||||
suffixLen = 4
|
||||
}
|
||||
|
||||
suffix, err := randomAlphanumeric(suffixLen)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate random name: %w", err)
|
||||
}
|
||||
|
||||
if prefix == "" {
|
||||
return suffix, nil
|
||||
}
|
||||
return prefix + "-" + suffix, nil
|
||||
}
|
||||
|
||||
func randomAlphanumeric(n int) (string, error) {
|
||||
result := make([]byte, n)
|
||||
charsetLen := big.NewInt(int64(len(alphanumCharset)))
|
||||
for i := range result {
|
||||
idx, err := rand.Int(rand.Reader, charsetLen)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result[i] = alphanumCharset[idx.Int64()]
|
||||
}
|
||||
return string(result), nil
|
||||
}
|
||||
@@ -0,0 +1,732 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/shared/hash/argon2id"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
func validProxy() *Service {
|
||||
return &Service{
|
||||
Name: "test",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 80, Protocol: "http", Enabled: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_Valid(t *testing.T) {
|
||||
require.NoError(t, validProxy().Validate())
|
||||
}
|
||||
|
||||
func TestValidate_EmptyName(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Name = ""
|
||||
assert.ErrorContains(t, rp.Validate(), "name is required")
|
||||
}
|
||||
|
||||
func TestValidate_EmptyDomain(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Domain = ""
|
||||
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
||||
}
|
||||
|
||||
func TestValidate_NoTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = nil
|
||||
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
||||
}
|
||||
|
||||
func TestValidate_EmptyTargetId(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].TargetId = ""
|
||||
assert.ErrorContains(t, rp.Validate(), "empty target_id")
|
||||
}
|
||||
|
||||
func TestValidate_InvalidTargetType(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].TargetType = "invalid"
|
||||
assert.ErrorContains(t, rp.Validate(), "invalid target_type")
|
||||
}
|
||||
|
||||
func TestValidate_ResourceTarget(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = append(rp.Targets, &Target{
|
||||
TargetId: "resource-1",
|
||||
TargetType: TargetTypeHost,
|
||||
Host: "example.org",
|
||||
Port: 443,
|
||||
Protocol: "https",
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_MultipleTargetsOneInvalid(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = append(rp.Targets, &Target{
|
||||
TargetId: "",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.2",
|
||||
Port: 80,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
})
|
||||
err := rp.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "target 1")
|
||||
assert.Contains(t, err.Error(), "empty target_id")
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_PathRewrite(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mode PathRewriteMode
|
||||
wantErr string
|
||||
}{
|
||||
{"empty is default", "", ""},
|
||||
{"preserve is valid", PathRewritePreserve, ""},
|
||||
{"unknown rejected", "regex", "unknown path_rewrite mode"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.PathRewrite = tt.mode
|
||||
err := rp.Validate()
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
wantErr string
|
||||
}{
|
||||
{"valid 30s", 30 * time.Second, ""},
|
||||
{"valid 2m", 2 * time.Minute, ""},
|
||||
{"zero is fine", 0, ""},
|
||||
{"negative", -1 * time.Second, "must be positive"},
|
||||
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.RequestTimeout = tt.timeout
|
||||
err := rp.Validate()
|
||||
if tt.wantErr == "" {
|
||||
assert.NoError(t, err)
|
||||
} else {
|
||||
assert.ErrorContains(t, err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateTargetOptions_CustomHeaders(t *testing.T) {
|
||||
t.Run("valid headers", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||
"X-Custom": "value",
|
||||
"X-Trace": "abc123",
|
||||
}
|
||||
assert.NoError(t, rp.Validate())
|
||||
})
|
||||
|
||||
t.Run("CRLF in key", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Bad\r\nKey": "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "not a valid HTTP header name")
|
||||
})
|
||||
|
||||
t.Run("CRLF in value", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Good": "bad\nvalue"}
|
||||
assert.ErrorContains(t, rp.Validate(), "invalid characters")
|
||||
})
|
||||
|
||||
t.Run("hop-by-hop header rejected", func(t *testing.T) {
|
||||
for _, h := range []string{"Connection", "Transfer-Encoding", "Keep-Alive", "Upgrade", "Proxy-Connection"} {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "hop-by-hop", "header %q should be rejected", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reserved header rejected", func(t *testing.T) {
|
||||
for _, h := range []string{"X-Forwarded-For", "X-Real-IP", "X-Forwarded-Proto", "X-Forwarded-Host", "X-Forwarded-Port", "Cookie", "Forwarded", "Content-Length", "Content-Type"} {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{h: "value"}
|
||||
assert.ErrorContains(t, rp.Validate(), "managed by the proxy", "header %q should be rejected", h)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Host header rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"Host": "evil.com"}
|
||||
assert.ErrorContains(t, rp.Validate(), "pass_host_header")
|
||||
})
|
||||
|
||||
t.Run("too many headers", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
headers := make(map[string]string, 17)
|
||||
for i := range 17 {
|
||||
headers[fmt.Sprintf("X-H%d", i)] = "v"
|
||||
}
|
||||
rp.Targets[0].Options.CustomHeaders = headers
|
||||
assert.ErrorContains(t, rp.Validate(), "exceeds maximum of 16")
|
||||
})
|
||||
|
||||
t.Run("key too long", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{strings.Repeat("X", 129): "v"}
|
||||
assert.ErrorContains(t, rp.Validate(), "key")
|
||||
assert.ErrorContains(t, rp.Validate(), "exceeds maximum length")
|
||||
})
|
||||
|
||||
t.Run("value too long", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{"X-Ok": strings.Repeat("v", 4097)}
|
||||
assert.ErrorContains(t, rp.Validate(), "value exceeds maximum length")
|
||||
})
|
||||
|
||||
t.Run("duplicate canonical keys rejected", func(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].Options.CustomHeaders = map[string]string{
|
||||
"x-custom": "a",
|
||||
"X-Custom": "b",
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "collide")
|
||||
})
|
||||
}
|
||||
|
||||
func TestToProtoMapping_TargetOptions(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "svc-1",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
Options: TargetOptions{
|
||||
SkipTLSVerify: true,
|
||||
RequestTimeout: 30 * time.Second,
|
||||
PathRewrite: PathRewritePreserve,
|
||||
CustomHeaders: map[string]string{"X-Custom": "val"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
|
||||
opts := pm.Path[0].Options
|
||||
require.NotNil(t, opts, "options should be populated")
|
||||
assert.True(t, opts.SkipTlsVerify)
|
||||
assert.Equal(t, proto.PathRewriteMode_PATH_REWRITE_PRESERVE, opts.PathRewrite)
|
||||
assert.Equal(t, map[string]string{"X-Custom": "val"}, opts.CustomHeaders)
|
||||
require.NotNil(t, opts.RequestTimeout)
|
||||
assert.Equal(t, int64(30), opts.RequestTimeout.Seconds)
|
||||
}
|
||||
|
||||
func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "svc-1",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: "10.0.0.1",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Nil(t, pm.Path[0].Options, "options should be nil when all defaults")
|
||||
}
|
||||
|
||||
func TestIsDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
scheme string
|
||||
port int
|
||||
want bool
|
||||
}{
|
||||
{"http", 80, true},
|
||||
{"https", 443, true},
|
||||
{"http", 443, false},
|
||||
{"https", 80, false},
|
||||
{"http", 8080, false},
|
||||
{"https", 8443, false},
|
||||
{"http", 0, false},
|
||||
{"https", 0, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%s/%d", tt.scheme, tt.port), func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, isDefaultPort(tt.scheme, tt.port))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
oidcConfig := proxy.OIDCValidationConfig{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
protocol string
|
||||
host string
|
||||
port int
|
||||
wantTarget string
|
||||
}{
|
||||
{
|
||||
name: "http with default port 80 omits port",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 80,
|
||||
wantTarget: "http://10.0.0.1/",
|
||||
},
|
||||
{
|
||||
name: "https with default port 443 omits port",
|
||||
protocol: "https",
|
||||
host: "10.0.0.1",
|
||||
port: 443,
|
||||
wantTarget: "https://10.0.0.1/",
|
||||
},
|
||||
{
|
||||
name: "port 0 omits port",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 0,
|
||||
wantTarget: "http://10.0.0.1/",
|
||||
},
|
||||
{
|
||||
name: "non-default port is included",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 8080,
|
||||
wantTarget: "http://10.0.0.1:8080/",
|
||||
},
|
||||
{
|
||||
name: "https with non-default port is included",
|
||||
protocol: "https",
|
||||
host: "10.0.0.1",
|
||||
port: 8443,
|
||||
wantTarget: "https://10.0.0.1:8443/",
|
||||
},
|
||||
{
|
||||
name: "http port 443 is included",
|
||||
protocol: "http",
|
||||
host: "10.0.0.1",
|
||||
port: 443,
|
||||
wantTarget: "http://10.0.0.1:443/",
|
||||
},
|
||||
{
|
||||
name: "https port 80 is included",
|
||||
protocol: "https",
|
||||
host: "10.0.0.1",
|
||||
port: 80,
|
||||
wantTarget: "https://10.0.0.1:80/",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "test-id",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{
|
||||
TargetId: "peer-1",
|
||||
TargetType: TargetTypePeer,
|
||||
Host: tt.host,
|
||||
Port: tt.port,
|
||||
Protocol: tt.protocol,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", oidcConfig)
|
||||
require.Len(t, pm.Path, 1, "should have one path mapping")
|
||||
assert.Equal(t, tt.wantTarget, pm.Path[0].Target)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtoMapping_DisabledTargetSkipped(t *testing.T) {
|
||||
rp := &Service{
|
||||
ID: "test-id",
|
||||
AccountID: "acc-1",
|
||||
Domain: "example.com",
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Host: "10.0.0.1", Port: 8080, Protocol: "http", Enabled: false},
|
||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Host: "10.0.0.2", Port: 9090, Protocol: "http", Enabled: true},
|
||||
},
|
||||
}
|
||||
pm := rp.ToProtoMapping(Create, "token", proxy.OIDCValidationConfig{})
|
||||
require.Len(t, pm.Path, 1)
|
||||
assert.Equal(t, "http://10.0.0.2:9090/", pm.Path[0].Target)
|
||||
}
|
||||
|
||||
func TestToProtoMapping_OperationTypes(t *testing.T) {
|
||||
rp := validProxy()
|
||||
tests := []struct {
|
||||
op Operation
|
||||
want proto.ProxyMappingUpdateType
|
||||
}{
|
||||
{Create, proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED},
|
||||
{Update, proto.ProxyMappingUpdateType_UPDATE_TYPE_MODIFIED},
|
||||
{Delete, proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(string(tt.op), func(t *testing.T) {
|
||||
pm := rp.ToProtoMapping(tt.op, "", proxy.OIDCValidationConfig{})
|
||||
assert.Equal(t, tt.want, pm.Type)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfig_HashSecrets(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config *AuthConfig
|
||||
wantErr bool
|
||||
validate func(*testing.T, *AuthConfig)
|
||||
}{
|
||||
{
|
||||
name: "hash password successfully",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "testPassword123",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||
t.Errorf("Password not hashed with argon2id, got: %s", config.PasswordAuth.Password)
|
||||
}
|
||||
// Verify the hash can be verified
|
||||
if err := argon2id.Verify("testPassword123", config.PasswordAuth.Password); err != nil {
|
||||
t.Errorf("Hash verification failed: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "hash PIN successfully",
|
||||
config: &AuthConfig{
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "123456",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||
t.Errorf("PIN not hashed with argon2id, got: %s", config.PinAuth.Pin)
|
||||
}
|
||||
// Verify the hash can be verified
|
||||
if err := argon2id.Verify("123456", config.PinAuth.Pin); err != nil {
|
||||
t.Errorf("Hash verification failed: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "hash both password and PIN",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "password",
|
||||
},
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "9999",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if !strings.HasPrefix(config.PasswordAuth.Password, "$argon2id$") {
|
||||
t.Errorf("Password not hashed with argon2id")
|
||||
}
|
||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||
t.Errorf("PIN not hashed with argon2id")
|
||||
}
|
||||
if err := argon2id.Verify("password", config.PasswordAuth.Password); err != nil {
|
||||
t.Errorf("Password hash verification failed: %v", err)
|
||||
}
|
||||
if err := argon2id.Verify("9999", config.PinAuth.Pin); err != nil {
|
||||
t.Errorf("PIN hash verification failed: %v", err)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip disabled password auth",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: false,
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if config.PasswordAuth.Password != "password" {
|
||||
t.Errorf("Disabled password auth should not be hashed")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip empty password",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if config.PasswordAuth.Password != "" {
|
||||
t.Errorf("Empty password should remain empty")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "skip nil password auth",
|
||||
config: &AuthConfig{
|
||||
PasswordAuth: nil,
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "1234",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
validate: func(t *testing.T, config *AuthConfig) {
|
||||
if config.PasswordAuth != nil {
|
||||
t.Errorf("PasswordAuth should remain nil")
|
||||
}
|
||||
if !strings.HasPrefix(config.PinAuth.Pin, "$argon2id$") {
|
||||
t.Errorf("PIN should still be hashed")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.HashSecrets()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("HashSecrets() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if tt.validate != nil {
|
||||
tt.validate(t, tt.config)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfig_HashSecrets_VerifyIncorrectSecret(t *testing.T) {
|
||||
config := &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "correctPassword",
|
||||
},
|
||||
}
|
||||
|
||||
if err := config.HashSecrets(); err != nil {
|
||||
t.Fatalf("HashSecrets() error = %v", err)
|
||||
}
|
||||
|
||||
// Verify with wrong password should fail
|
||||
err := argon2id.Verify("wrongPassword", config.PasswordAuth.Password)
|
||||
if !errors.Is(err, argon2id.ErrMismatchedHashAndPassword) {
|
||||
t.Errorf("Expected ErrMismatchedHashAndPassword, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthConfig_ClearSecrets(t *testing.T) {
|
||||
config := &AuthConfig{
|
||||
PasswordAuth: &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: "hashedPassword",
|
||||
},
|
||||
PinAuth: &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: "hashedPin",
|
||||
},
|
||||
}
|
||||
|
||||
config.ClearSecrets()
|
||||
|
||||
if config.PasswordAuth.Password != "" {
|
||||
t.Errorf("Password not cleared, got: %s", config.PasswordAuth.Password)
|
||||
}
|
||||
if config.PinAuth.Pin != "" {
|
||||
t.Errorf("PIN not cleared, got: %s", config.PinAuth.Pin)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateExposeName(t *testing.T) {
|
||||
t.Run("no prefix generates 12-char name", func(t *testing.T) {
|
||||
name, err := GenerateExposeName("")
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, name, 12)
|
||||
assert.Regexp(t, `^[a-z0-9]+$`, name)
|
||||
})
|
||||
|
||||
t.Run("with prefix generates prefix-XXXX", func(t *testing.T) {
|
||||
name, err := GenerateExposeName("myapp")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, strings.HasPrefix(name, "myapp-"), "name should start with prefix")
|
||||
suffix := strings.TrimPrefix(name, "myapp-")
|
||||
assert.Len(t, suffix, 4, "suffix should be 4 chars")
|
||||
assert.Regexp(t, `^[a-z0-9]+$`, suffix)
|
||||
})
|
||||
|
||||
t.Run("unique names", func(t *testing.T) {
|
||||
names := make(map[string]bool)
|
||||
for i := 0; i < 50; i++ {
|
||||
name, err := GenerateExposeName("")
|
||||
require.NoError(t, err)
|
||||
names[name] = true
|
||||
}
|
||||
assert.Greater(t, len(names), 45, "should generate mostly unique names")
|
||||
})
|
||||
|
||||
t.Run("valid prefixes", func(t *testing.T) {
|
||||
validPrefixes := []string{"a", "ab", "a1", "my-app", "web-server-01", "a-b"}
|
||||
for _, prefix := range validPrefixes {
|
||||
name, err := GenerateExposeName(prefix)
|
||||
assert.NoError(t, err, "prefix %q should be valid", prefix)
|
||||
assert.True(t, strings.HasPrefix(name, prefix+"-"), "name should start with %q-", prefix)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid prefixes", func(t *testing.T) {
|
||||
invalidPrefixes := []string{
|
||||
"-starts-with-dash",
|
||||
"ends-with-dash-",
|
||||
"has.dots",
|
||||
"HAS-UPPER",
|
||||
"has spaces",
|
||||
"has/slash",
|
||||
"a--",
|
||||
}
|
||||
for _, prefix := range invalidPrefixes {
|
||||
_, err := GenerateExposeName(prefix)
|
||||
assert.Error(t, err, "prefix %q should be invalid", prefix)
|
||||
assert.Contains(t, err.Error(), "invalid name prefix")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExposeServiceRequest_ToService(t *testing.T) {
|
||||
t.Run("basic HTTP service", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
}
|
||||
|
||||
service := req.ToService("account-1", "peer-1", "mysvc")
|
||||
|
||||
assert.Equal(t, "account-1", service.AccountID)
|
||||
assert.Equal(t, "mysvc", service.Name)
|
||||
assert.True(t, service.Enabled)
|
||||
assert.Empty(t, service.Domain, "domain should be empty when not specified")
|
||||
require.Len(t, service.Targets, 1)
|
||||
|
||||
target := service.Targets[0]
|
||||
assert.Equal(t, 8080, target.Port)
|
||||
assert.Equal(t, "http", target.Protocol)
|
||||
assert.Equal(t, "peer-1", target.TargetId)
|
||||
assert.Equal(t, TargetTypePeer, target.TargetType)
|
||||
assert.True(t, target.Enabled)
|
||||
assert.Equal(t, "account-1", target.AccountID)
|
||||
})
|
||||
|
||||
t.Run("with custom domain", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 3000,
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "web")
|
||||
assert.Equal(t, "web.example.com", service.Domain)
|
||||
})
|
||||
|
||||
t.Run("with PIN auth", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 80,
|
||||
Pin: "1234",
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "svc")
|
||||
require.NotNil(t, service.Auth.PinAuth)
|
||||
assert.True(t, service.Auth.PinAuth.Enabled)
|
||||
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
|
||||
assert.Nil(t, service.Auth.PasswordAuth)
|
||||
assert.Nil(t, service.Auth.BearerAuth)
|
||||
})
|
||||
|
||||
t.Run("with password auth", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 80,
|
||||
Password: "secret",
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "svc")
|
||||
require.NotNil(t, service.Auth.PasswordAuth)
|
||||
assert.True(t, service.Auth.PasswordAuth.Enabled)
|
||||
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
|
||||
})
|
||||
|
||||
t.Run("with user groups (bearer auth)", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 80,
|
||||
UserGroups: []string{"admins", "devs"},
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "svc")
|
||||
require.NotNil(t, service.Auth.BearerAuth)
|
||||
assert.True(t, service.Auth.BearerAuth.Enabled)
|
||||
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
|
||||
})
|
||||
|
||||
t.Run("with all auth types", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 443,
|
||||
Domain: "myco.com",
|
||||
Pin: "9999",
|
||||
Password: "pass",
|
||||
UserGroups: []string{"ops"},
|
||||
}
|
||||
|
||||
service := req.ToService("acc", "peer", "full")
|
||||
assert.Equal(t, "full.myco.com", service.Domain)
|
||||
require.NotNil(t, service.Auth.PinAuth)
|
||||
require.NotNil(t, service.Auth.PasswordAuth)
|
||||
require.NotNil(t, service.Auth.BearerAuth)
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user