Merge branch 'main' into refactor/permissions-manager

# Conflicts:
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-03-27 14:37:29 +01:00
125 changed files with 11320 additions and 413 deletions

View File

@@ -36,6 +36,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
router.HandleFunc("/reverse-proxies/clusters", permissionsManager.WithPermission(modules.Services, operations.Read, h.getClusters)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
@@ -151,3 +152,21 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}

View File

@@ -75,6 +75,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()

View File

@@ -14,6 +14,8 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
@@ -95,6 +97,19 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
services, err := m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -192,6 +207,10 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
return err
}
}
service.AccountID = accountID
@@ -217,6 +236,20 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
// validateSubdomainRequirement checks whether the domain can be used bare
// (without a subdomain label) on the given cluster. If the cluster reports
// require_subdomain=true and the domain equals the cluster domain, it rejects.
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
if domain != cluster {
return nil
}
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
if requireSub != nil && *requireSub {
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
}
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if svc.Domain != "" {
@@ -437,51 +470,63 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
})
return &updateInfo, err
}
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
if existingService.Terminated {
return status.Errorf(status.PermissionDenied, "service is terminated and cannot be updated")
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
return err
}
return nil
})
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
return &updateInfo, err
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
@@ -599,18 +644,12 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
for _, target := range targets {
switch target.TargetType {
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
@@ -619,6 +658,39 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
return nil
}
func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
return validateResourceTargetType(target, resource)
}
// validateResourceTargetType checks that target_type matches the actual network resource type.
func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error {
expected := resourcetypes.NetworkResourceType(target.TargetType)
if resource.Type != expected {
return status.Errorf(status.InvalidArgument,
"target %q has target_type %q but resource is of type %q",
target.TargetId, target.TargetType, resource.Type,
)
}
return nil
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
var s *service.Service
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {

View File

@@ -22,6 +22,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -1206,3 +1207,126 @@ func TestValidateProtocolChange(t *testing.T) {
})
}
}
func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
tests := []struct {
name string
targetType rpservice.TargetType
resourceType resourcetypes.NetworkResourceType
wantErr bool
}{
{"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false},
{"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false},
{"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false},
{"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true},
{"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true},
{"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true},
{"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore.EXPECT().
GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1").
Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil)
targets := []*rpservice.Target{
{TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "target_type")
} else {
require.NoError(t, err)
}
})
}
}
func TestValidateTargetReferences_PeerValid(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mockStore.EXPECT().
GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1").
Return(&nbpeer.Peer{}, nil)
targets := []*rpservice.Target{
{TargetId: "peer-1", TargetType: rpservice.TargetTypePeer},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets))
}
func TestValidateSubdomainRequirement(t *testing.T) {
ptrBool := func(b bool) *bool { return &b }
tests := []struct {
name string
domain string
cluster string
requireSubdomain *bool
wantErr bool
}{
{
name: "subdomain present, require_subdomain true",
domain: "app.eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: false,
},
{
name: "bare cluster domain, require_subdomain true",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: true,
},
{
name: "bare cluster domain, require_subdomain false",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(false),
wantErr: false,
},
{
name: "bare cluster domain, require_subdomain nil (default)",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: nil,
wantErr: false,
},
{
name: "custom domain apex is not the cluster",
domain: "example.com",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mgr := &Manager{proxyController: mockCtrl}
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "requires a subdomain label")
} else {
require.NoError(t, err)
}
})
}
}