Merge remote-tracking branch 'origin/main' into refactor/permissions-manager

# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/internals/server/modules.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-03-17 12:38:08 +01:00
244 changed files with 17304 additions and 3509 deletions

View File

@@ -4,7 +4,10 @@ import (
"context"
"fmt"
"math/rand/v2"
"net/http"
"os"
"slices"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -20,6 +23,45 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const (
defaultAutoAssignPortMin uint16 = 10000
defaultAutoAssignPortMax uint16 = 49151
// EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN"
// EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX"
)
var (
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
)
func init() {
autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin)
autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax)
if autoAssignPortMin > autoAssignPortMax {
log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults",
EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax)
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
}
}
func portFromEnv(key string, fallback uint16) uint16 {
val := os.Getenv(key)
if val == "" {
return fallback
}
n, err := strconv.ParseUint(val, 10, 16)
if err != nil {
log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err)
return fallback
}
return uint16(n)
}
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
@@ -102,6 +144,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
@@ -158,6 +201,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return fmt.Errorf("hash secrets: %w", err)
}
for i, h := range service.Auth.HeaderAuths {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
@@ -168,55 +217,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
@@ -232,17 +245,161 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
svc.ListenPort = 0
}
if svc.ListenPort == 0 {
port, err := m.assignPort(ctx, tx, svc.ProxyCluster)
if err != nil {
return err
}
svc.ListenPort = port
svc.PortAutoAssigned = true
}
return nil
}
// checkPortConflict rejects L4 services that would conflict on the same listener.
// For TCP/UDP: unique per cluster+protocol+port.
// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports).
// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked:
// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener.
func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 {
return nil
}
existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort)
if err != nil {
return fmt.Errorf("query port conflicts: %w", err)
}
for _, s := range existing {
if s.ID == svc.ID {
continue
}
// TLS services on the same port are allowed if they have different domains (SNI routing)
if svc.Mode == service.ModeTLS && s.Domain != svc.Domain {
continue
}
return status.Errorf(status.AlreadyExists,
"%s port %d is already in use by service %q on cluster %s",
svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster)
}
return nil
}
// assignPort picks a random available port on the cluster within the auto-assign range.
func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) {
services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster)
if err != nil {
return 0, fmt.Errorf("query cluster ports: %w", err)
}
occupied := make(map[uint16]struct{}, len(services))
for _, s := range services {
if s.ListenPort > 0 {
occupied[s.ListenPort] = struct{}{}
}
}
portRange := int(autoAssignPortMax-autoAssignPortMin) + 1
for range 100 {
port := autoAssignPortMin + uint16(rand.IntN(portRange))
if _, taken := occupied[port]; !taken {
return port, nil
}
}
for port := autoAssignPortMin; port <= autoAssignPortMax; port++ {
if _, taken := occupied[port]; !taken {
return port, nil
}
}
return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster)
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error {
// Lock the peer row to serialize concurrent creates for the same peer.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
return nil
}
// checkDomainAvailable checks that no other service already uses this domain.
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
return fmt.Errorf("check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
return status.Errorf(status.AlreadyExists, "domain already taken")
}
return nil
@@ -285,6 +442,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
@@ -297,13 +458,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -314,35 +484,85 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
service.ProxyCluster = newCluster
svc.ProxyCluster = newCluster
}
}
return nil
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
// validateProtocolChange rejects mode changes on update.
// Only empty<->HTTP is allowed; all other transitions are rejected.
func validateProtocolChange(oldMode, newMode string) error {
if newMode == "" || newMode == oldMode {
return nil
}
if isHTTPFamily(oldMode) && isHTTPFamily(newMode) {
return nil
}
return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode)
}
func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http"
}
func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
svc.Auth.PasswordAuth.Password == "" {
svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
svc.Auth.PinAuth.Pin == "" {
svc.Auth.PinAuth = existingService.Auth.PinAuth
}
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}
// preserveHeaderAuthHashes fills in empty header auth values from the existing
// service so that unchanged secrets are not lost on update.
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
if len(headers) == 0 || len(existing) == 0 {
return
}
existingByHeader := make(map[string]string, len(existing))
for _, h := range existing {
if h != nil && h.Value != "" {
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
}
}
for _, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
h.Value = hash
}
}
}
}
// validateHeaderAuthValues checks that all enabled header auths have a value
// (either freshly provided or preserved from the existing service).
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
for i, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
return nil
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
@@ -351,11 +571,18 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) preserveListenPort(svc, existing *service.Service) {
if existing.ListenPort > 0 && svc.ListenPort == 0 {
svc.ListenPort = existing.ListenPort
svc.PortAutoAssigned = existing.PortAutoAssigned
}
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
case updateInfo.domainChanged || updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
@@ -385,6 +612,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
}
return nil
@@ -622,6 +851,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) {
return m.buildRandomDomain(serviceName)
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
@@ -643,9 +876,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
domain, err := m.resolveDefaultDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
return nil, err
}
svc.Domain = domain
}
@@ -686,10 +919,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
serviceURL := "https://" + svc.Domain
if service.IsL4Protocol(svc.Mode) {
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
}
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
ServiceName: svc.Name,
ServiceURL: serviceURL,
Domain: svc.Domain,
PortAutoAssigned: svc.PortAutoAssigned,
}, nil
}
@@ -708,64 +947,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
func (m *Manager) getDefaultClusterDomain() (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
return "", fmt.Errorf("unable to get cluster domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
return "", fmt.Errorf("no cluster domains available")
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
return clusterDomains[rand.IntN(len(clusterDomains))], nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
domain, err := m.getDefaultClusterDomain()
if err != nil {
return "", err
}
return name + "." + domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// deleteServiceFromPeer deletes a peer-initiated service identified by service ID.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error {
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode)
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {