Merge remote-tracking branch 'origin/main' into refactor/permissions-manager

# Conflicts:
#	management/internals/modules/reverseproxy/domain/manager/manager.go
#	management/internals/modules/reverseproxy/service/manager/api.go
#	management/internals/server/modules.go
#	management/server/http/testing/testing_tools/channel/channel.go
This commit is contained in:
pascal
2026-03-17 12:38:08 +01:00
244 changed files with 17304 additions and 3509 deletions

View File

@@ -22,7 +22,7 @@ type Manager interface {
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
}

View File

@@ -211,17 +211,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
}
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
@@ -265,17 +265,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// UpdateService mocks base method.

View File

@@ -20,13 +20,15 @@ import (
)
type handler struct {
manager rpservice.Manager
manager rpservice.Manager
permissionsManager permissions.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{
manager: manager,
manager: manager,
permissionsManager: permissionsManager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()

View File

@@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -28,19 +28,19 @@ func TestReapExpiredExposes(t *testing.T) {
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
Port: 8081,
Mode: "http",
})
require.NoError(t, err)
mgr.exposeReaper.reapExpiredExposes(ctx)
// Expired service should be deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired service should be deleted")
// Non-expired service should remain
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
_, err = testStore.GetServiceByDomain(ctx, resp2.Domain)
require.NoError(t, err, "active service should remain")
}
@@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
@@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) {
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) {
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
Port: 8082,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, lookupErr)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
@@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) {
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
Port: 8083,
Mode: "http",
})
require.NoError(t, err)
@@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) {
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
Port: uint16(8090 + i),
Mode: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
Port: 9999,
Mode: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
@@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
Port: 8086,
Mode: "http",
})
require.NoError(t, err)
@@ -185,20 +188,30 @@ func TestReapSkipsRenewedService(t *testing.T) {
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
mgr.exposeReaper.reapExpiredExposes(ctx)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err, "renewed service should survive reaping")
}
// resolveServiceIDByDomain looks up a service ID by domain in tests.
func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
return svc.ID
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
expired := time.Now().Add(-2 * exposeTTL)

View File

@@ -0,0 +1,580 @@
package manager
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const testCluster = "test-cluster"
func boolPtr(v bool) *bool { return &v }
// setupL4Test creates a manager with a mock proxy controller for L4 port tests.
func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) {
t.Helper()
ctrl := gomock.NewController(t)
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
CreatedBy: testUserID,
Settings: &types.Settings{
PeerExposeEnabled: true,
PeerExposeGroups: []string{testGroupID},
},
Users: map[string]*types.User{
testUserID: {
Id: testUserID,
AccountID: testAccountID,
Role: types.UserRoleAdmin,
},
},
Peers: map[string]*nbpeer.Peer{
testPeerID: {
ID: testPeerID,
AccountID: testAccountID,
Key: "test-key",
DNSLabel: "test-peer",
Name: "test-peer",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
},
},
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
AccountID: testAccountID,
Name: "Expose Group",
},
},
})
require.NoError(t, err)
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
proxyController: mockCtrl,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr, testStore, mockCtrl
}
// seedService creates a service directly in the store for test setup.
func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service {
t.Helper()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: name,
Mode: protocol,
Domain: domain,
ProxyCluster: cluster,
ListenPort: port,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := s.CreateService(context.Background(), svc)
require.NoError(t, err)
return svc
}
func TestPortConflict_TCPSamePortCluster(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "conflicting-tcp",
Mode: "tcp",
Domain: "conflicting-tcp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TCP+TCP on same port/cluster should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestPortConflict_UDPSamePortCluster(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "conflicting-udp",
Mode: "udp",
Domain: "conflicting-udp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "UDP+UDP on same port/cluster should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "new-tls",
Mode: "tls",
Domain: "app2.example.com",
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)")
}
func TestPortConflict_TLSSamePortSameDomain(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "duplicate-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TLS+TLS on same domain should be rejected")
assert.Contains(t, err.Error(), "domain already taken")
}
func TestPortConflict_TLSAndTCPSamePort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "new-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)")
}
func TestAutoAssign_TCPNoListenPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "auto-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set")
}
func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it")
assert.Contains(t, err.Error(), "custom ports")
}
func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability")
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS")
}
func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "ephemeral-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "ephemeral",
SourcePeer: testPeerID,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
require.NoError(t, err)
assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden")
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range", svc.ListenPort)
assert.True(t, svc.PortAutoAssigned)
}
func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "ephemeral-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Source: "ephemeral",
SourcePeer: testPeerID,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
require.NoError(t, err)
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
assert.False(t, svc.PortAutoAssigned)
}
func TestAutoAssign_AvoidsExistingPorts(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existingPort := uint16(20000)
seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "auto-tcp",
Mode: "tcp",
Domain: "auto-tcp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing")
assert.True(t, svc.PortAutoAssigned)
}
func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported")
assert.False(t, svc.PortAutoAssigned)
}
func TestUpdate_PreservesExistingListenPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0")
}
func TestUpdate_AllowsPortChange(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
})
require.NoError(t, err)
assert.NotEmpty(t, resp.ServiceName)
assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain")
assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports")
assert.Contains(t, resp.ServiceURL, "tcp://")
}
func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
ListenPort: 15432,
})
require.NoError(t, err)
assert.False(t, resp.PortAutoAssigned)
assert.Contains(t, resp.ServiceURL, ":15432")
}
func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
})
require.NoError(t, err)
// When no explicit listen port, defaults to target port
assert.Contains(t, resp.ServiceURL, ":5432")
assert.False(t, resp.PortAutoAssigned)
}
func TestCreateServiceFromPeer_TLS(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 443,
Mode: "tls",
})
require.NoError(t, err)
assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain")
assert.Contains(t, resp.ServiceURL, "tls://")
assert.Contains(t, resp.ServiceURL, ":443")
// TLS always keeps its port (not port-based protocol for auto-assign)
assert.False(t, resp.PortAutoAssigned)
}
func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "tcp",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Renew after stop should fail
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.Error(t, err)
}
func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "tcp",
Pin: "123456",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication is not supported")
}

View File

@@ -4,7 +4,10 @@ import (
"context"
"fmt"
"math/rand/v2"
"net/http"
"os"
"slices"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -20,6 +23,45 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const (
defaultAutoAssignPortMin uint16 = 10000
defaultAutoAssignPortMax uint16 = 49151
// EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN"
// EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX"
)
var (
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
)
func init() {
autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin)
autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax)
if autoAssignPortMin > autoAssignPortMax {
log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults",
EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax)
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
}
}
func portFromEnv(key string, fallback uint16) uint16 {
val := os.Getenv(key)
if val == "" {
return fallback
}
n, err := strconv.ParseUint(val, 10, 16)
if err != nil {
log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err)
return fallback
}
return uint16(n)
}
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
@@ -102,6 +144,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
@@ -158,6 +201,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return fmt.Errorf("hash secrets: %w", err)
}
for i, h := range service.Auth.HeaderAuths {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
@@ -168,55 +217,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, ""); err != nil {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, accountID, svc.Domain, ""); err != nil {
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
@@ -232,17 +245,161 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
})
}
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, accountID, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, accountID, domain)
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
svc.ListenPort = 0
}
if svc.ListenPort == 0 {
port, err := m.assignPort(ctx, tx, svc.ProxyCluster)
if err != nil {
return err
}
svc.ListenPort = port
svc.PortAutoAssigned = true
}
return nil
}
// checkPortConflict rejects L4 services that would conflict on the same listener.
// For TCP/UDP: unique per cluster+protocol+port.
// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports).
// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked:
// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener.
func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 {
return nil
}
existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort)
if err != nil {
return fmt.Errorf("query port conflicts: %w", err)
}
for _, s := range existing {
if s.ID == svc.ID {
continue
}
// TLS services on the same port are allowed if they have different domains (SNI routing)
if svc.Mode == service.ModeTLS && s.Domain != svc.Domain {
continue
}
return status.Errorf(status.AlreadyExists,
"%s port %d is already in use by service %q on cluster %s",
svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster)
}
return nil
}
// assignPort picks a random available port on the cluster within the auto-assign range.
func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) {
services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster)
if err != nil {
return 0, fmt.Errorf("query cluster ports: %w", err)
}
occupied := make(map[uint16]struct{}, len(services))
for _, s := range services {
if s.ListenPort > 0 {
occupied[s.ListenPort] = struct{}{}
}
}
portRange := int(autoAssignPortMax-autoAssignPortMin) + 1
for range 100 {
port := autoAssignPortMin + uint16(rand.IntN(portRange))
if _, taken := occupied[port]; !taken {
return port, nil
}
}
for port := autoAssignPortMin; port <= autoAssignPortMax; port++ {
if _, taken := occupied[port]; !taken {
return port, nil
}
}
return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster)
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error {
// Lock the peer row to serialize concurrent creates for the same peer.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
return nil
}
// checkDomainAvailable checks that no other service already uses this domain.
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
return fmt.Errorf("check existing service: %w", err)
}
return nil
}
if existingService != nil && existingService.ID != excludeServiceID {
return status.Errorf(status.AlreadyExists, "service with domain %s already exists", domain)
return status.Errorf(status.AlreadyExists, "domain already taken")
}
return nil
@@ -285,6 +442,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
@@ -297,13 +458,22 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -314,35 +484,85 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, accountID, service.Domain, service.ID); err != nil {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
service.ProxyCluster = newCluster
svc.ProxyCluster = newCluster
}
}
return nil
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
// validateProtocolChange rejects mode changes on update.
// Only empty<->HTTP is allowed; all other transitions are rejected.
func validateProtocolChange(oldMode, newMode string) error {
if newMode == "" || newMode == oldMode {
return nil
}
if isHTTPFamily(oldMode) && isHTTPFamily(newMode) {
return nil
}
return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode)
}
func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http"
}
func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
svc.Auth.PasswordAuth.Password == "" {
svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
svc.Auth.PinAuth.Pin == "" {
svc.Auth.PinAuth = existingService.Auth.PinAuth
}
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}
// preserveHeaderAuthHashes fills in empty header auth values from the existing
// service so that unchanged secrets are not lost on update.
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
if len(headers) == 0 || len(existing) == 0 {
return
}
existingByHeader := make(map[string]string, len(existing))
for _, h := range existing {
if h != nil && h.Value != "" {
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
}
}
for _, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
h.Value = hash
}
}
}
}
// validateHeaderAuthValues checks that all enabled header auths have a value
// (either freshly provided or preserved from the existing service).
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
for i, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
return nil
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
@@ -351,11 +571,18 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) preserveListenPort(svc, existing *service.Service) {
if existing.ListenPort > 0 && svc.ListenPort == 0 {
svc.ListenPort = existing.ListenPort
svc.PortAutoAssigned = existing.PortAutoAssigned
}
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
case updateInfo.domainChanged || updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
@@ -385,6 +612,8 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
}
return nil
@@ -622,6 +851,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) {
return m.buildRandomDomain(serviceName)
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
@@ -643,9 +876,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
domain, err := m.resolveDefaultDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
return nil, err
}
svc.Domain = domain
}
@@ -686,10 +919,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
serviceURL := "https://" + svc.Domain
if service.IsL4Protocol(svc.Mode) {
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
}
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
ServiceName: svc.Name,
ServiceURL: serviceURL,
Domain: svc.Domain,
PortAutoAssigned: svc.PortAutoAssigned,
}, nil
}
@@ -708,64 +947,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
func (m *Manager) getDefaultClusterDomain() (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
return "", fmt.Errorf("unable to get cluster domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
return "", fmt.Errorf("no cluster domains available")
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
return clusterDomains[rand.IntN(len(clusterDomains))], nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
domain, err := m.getDefaultClusterDomain()
if err != nil {
return "", err
}
return name + "." + domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// deleteServiceFromPeer deletes a peer-initiated service identified by service ID.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error {
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode)
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {

View File

@@ -72,7 +72,6 @@ func TestInitializeServiceForCreate(t *testing.T) {
func TestCheckDomainAvailable(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
tests := []struct {
name string
@@ -88,7 +87,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "available.com").
GetServiceByDomain(ctx, "available.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
},
expectedError: false,
@@ -99,7 +98,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
GetServiceByDomain(ctx, "exists.com").
Return(&rpservice.Service{ID: "existing-id", Domain: "exists.com"}, nil)
},
expectedError: true,
@@ -111,7 +110,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "service-123",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
GetServiceByDomain(ctx, "exists.com").
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: false,
@@ -122,7 +121,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "service-456",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "exists.com").
GetServiceByDomain(ctx, "exists.com").
Return(&rpservice.Service{ID: "service-123", Domain: "exists.com"}, nil)
},
expectedError: true,
@@ -134,7 +133,7 @@ func TestCheckDomainAvailable(t *testing.T) {
excludeServiceID: "",
setupMock: func(ms *store.MockStore) {
ms.EXPECT().
GetServiceByDomain(ctx, accountID, "error.com").
GetServiceByDomain(ctx, "error.com").
Return(nil, errors.New("database error"))
},
expectedError: true,
@@ -150,7 +149,7 @@ func TestCheckDomainAvailable(t *testing.T) {
tt.setupMock(mockStore)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, tt.domain, tt.excludeServiceID)
err := mgr.checkDomainAvailable(ctx, mockStore, tt.domain, tt.excludeServiceID)
if tt.expectedError {
require.Error(t, err)
@@ -168,7 +167,6 @@ func TestCheckDomainAvailable(t *testing.T) {
func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
ctx := context.Background()
accountID := "test-account"
t.Run("empty domain", func(t *testing.T) {
ctrl := gomock.NewController(t)
@@ -176,11 +174,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "").
GetServiceByDomain(ctx, "").
Return(nil, status.Errorf(status.NotFound, "not found"))
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "", "")
err := mgr.checkDomainAvailable(ctx, mockStore, "", "")
assert.NoError(t, err)
})
@@ -191,11 +189,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "test.com").
GetServiceByDomain(ctx, "test.com").
Return(&rpservice.Service{ID: "some-id", Domain: "test.com"}, nil)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "test.com", "")
err := mgr.checkDomainAvailable(ctx, mockStore, "test.com", "")
assert.Error(t, err)
sErr, ok := status.FromError(err)
@@ -209,11 +207,11 @@ func TestCheckDomainAvailable_EdgeCases(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().
GetServiceByDomain(ctx, accountID, "nil.com").
GetServiceByDomain(ctx, "nil.com").
Return(nil, nil)
mgr := &Manager{}
err := mgr.checkDomainAvailable(ctx, mockStore, accountID, "nil.com", "")
err := mgr.checkDomainAvailable(ctx, mockStore, "nil.com", "")
assert.NoError(t, err)
})
@@ -241,7 +239,7 @@ func TestPersistNewService(t *testing.T) {
// Create another mock for the transaction
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "new.com").
GetServiceByDomain(ctx, "new.com").
Return(nil, status.Errorf(status.NotFound, "not found"))
txMock.EXPECT().
CreateService(ctx, service).
@@ -272,7 +270,7 @@ func TestPersistNewService(t *testing.T) {
DoAndReturn(func(ctx context.Context, fn func(store.Store) error) error {
txMock := store.NewMockStore(ctrl)
txMock.EXPECT().
GetServiceByDomain(ctx, accountID, "existing.com").
GetServiceByDomain(ctx, "existing.com").
Return(&rpservice.Service{ID: "other-id", Domain: "existing.com"}, nil)
return fn(txMock)
@@ -425,8 +423,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
t.Helper()
tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(srv.Close)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
return srv
}
@@ -703,8 +702,9 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -800,8 +800,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -811,7 +811,7 @@ func TestCreateServiceFromPeer(t *testing.T) {
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
// Verify service is persisted in store
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
persisted, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, rpservice.SourceEphemeral, persisted.Source, "source should be ephemeral")
@@ -823,9 +823,9 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
Port: 80,
Mode: "http",
Domain: "example.com",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -844,8 +844,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
require.NoError(t, err)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -857,8 +857,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
Port: 0,
Mode: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -875,62 +875,52 @@ func TestExposeServiceRequestValidate(t *testing.T) {
}{
{
name: "valid http request",
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
name: "https mode rejected",
req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"},
wantErr: "unsupported mode",
},
{
name: "port zero rejected",
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
name: "unsupported mode",
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"},
wantErr: "unsupported mode",
},
{
name: "invalid pin format",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
@@ -963,32 +953,33 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
// First create a service
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false)
require.NoError(t, err)
// Verify service is deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "service should be deleted")
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true)
require.NoError(t, err)
})
}
@@ -1000,16 +991,17 @@ func TestStopServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "service should be deleted")
})
}
@@ -1019,8 +1011,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -1028,7 +1020,7 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, int64(1), count, "one ephemeral service should exist after create")
svc, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
@@ -1039,8 +1031,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
Port: 9090,
Mode: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete")
}
@@ -1051,8 +1043,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -1073,21 +1065,22 @@ func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
})
}
@@ -1132,8 +1125,9 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
t.Cleanup(proxySrv.Close)
pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100)
require.NoError(t, err)
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil)
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
require.NoError(t, err)
@@ -1182,3 +1176,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string
oldP string
newP string
wantErr bool
}{
{"empty to http", "", "http", false},
{"http to http", "http", "http", false},
{"same protocol", "tcp", "tcp", false},
{"empty new proto", "tcp", "", false},
{"http to tcp", "http", "tcp", true},
{"tcp to udp", "tcp", "udp", true},
{"tls to http", "tls", "http", true},
{"udp to tls", "udp", "tls", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateProtocolChange(tt.oldP, tt.newP)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot change mode")
} else {
require.NoError(t, err)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) {
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
assert.ErrorContains(t, rp.Validate(), "at least one target is required")
}
func TestValidate_EmptyTargetId(t *testing.T) {
@@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"valid 10m", 10 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
port uint16
want bool
}{
{"http", 80, true},
@@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
name string
protocol string
host string
port int
port uint16
wantTarget string
}{
{
@@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) {
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
service := req.ToService("account-1", "peer-1", "mysvc")
@@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
require.Len(t, service.Targets, 1)
target := service.Targets[0]
assert.Equal(t, 8080, target.Port)
assert.Equal(t, uint16(8080), target.Port)
assert.Equal(t, "http", target.Protocol)
assert.Equal(t, "peer-1", target.TargetId)
assert.Equal(t, TargetTypePeer, target.TargetType)
@@ -730,3 +730,182 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
require.NotNil(t, service.Auth.BearerAuth)
})
}
func TestValidate_TLSOnly(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_TLSMissingListenPort(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 0,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "listen_port is required")
}
func TestValidate_TLSMissingDomain(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_TCPValid(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_TCPMissingListenPort(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)")
}
func TestValidate_L4MultipleTargets(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
{TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "exactly one target")
}
func TestValidate_L4TargetMissingPort(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "port is required")
}
func TestValidate_TLSInvalidTargetType(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.Error(t, rp.Validate())
}
func TestValidate_TLSSubnetValid(t *testing.T) {
rp := &Service{
Name: "tls-subnet",
Mode: "tls",
Domain: "example.com",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_HTTPProxyProtocolRejected(t *testing.T) {
rp := validProxy()
rp.Targets[0].ProxyProtocol = true
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP")
}
func TestValidate_UDPProxyProtocolRejected(t *testing.T) {
rp := &Service{
Name: "udp-svc",
Mode: "udp",
Domain: "cluster.test",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true},
},
}
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP")
}
func TestValidate_TCPProxyProtocolAllowed(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true},
},
}
require.NoError(t, rp.Validate())
}
func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) {
tests := []struct {
name string
req ExposeServiceRequest
}{
{
name: "tcp with pin",
req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"},
},
{
name: "udp with password",
req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"},
},
{
name: "tls with user groups",
req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication is not supported")
})
}
}
func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
require.NoError(t, req.Validate())
}