[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -22,7 +22,7 @@ type Manager interface {
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
}

View File

@@ -211,17 +211,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
}
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
@@ -265,17 +265,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// UpdateService mocks base method.

View File

@@ -11,19 +11,22 @@ import (
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager rpservice.Manager
manager rpservice.Manager
permissionsManager permissions.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{
manager: manager,
manager: manager,
permissionsManager: permissionsManager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()

View File

@@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) {
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
Port: 8081,
Mode: "http",
})
require.NoError(t, err)
@@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
@@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) {
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) {
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
Port: 8082,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, lookupErr)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
@@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) {
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
Port: 8083,
Mode: "http",
})
require.NoError(t, err)
@@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) {
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
Port: uint16(8090 + i),
Mode: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
Port: 9999,
Mode: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
@@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
Port: 8086,
Mode: "http",
})
require.NoError(t, err)
@@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) {
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
@@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
require.NoError(t, err, "renewed service should survive reaping")
}
// resolveServiceIDByDomain looks up a service ID by domain in tests.
func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
return svc.ID
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()

View File

@@ -0,0 +1,582 @@
package manager
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const testCluster = "test-cluster"
func boolPtr(v bool) *bool { return &v }
// setupL4Test creates a manager with a mock proxy controller for L4 port tests.
func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) {
t.Helper()
ctrl := gomock.NewController(t)
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
CreatedBy: testUserID,
Settings: &types.Settings{
PeerExposeEnabled: true,
PeerExposeGroups: []string{testGroupID},
},
Users: map[string]*types.User{
testUserID: {
Id: testUserID,
AccountID: testAccountID,
Role: types.UserRoleAdmin,
},
},
Peers: map[string]*nbpeer.Peer{
testPeerID: {
ID: testPeerID,
AccountID: testAccountID,
Key: "test-key",
DNSLabel: "test-peer",
Name: "test-peer",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
},
},
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
AccountID: testAccountID,
Name: "Expose Group",
},
},
})
require.NoError(t, err)
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore),
proxyController: mockCtrl,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr, testStore, mockCtrl
}
// seedService creates a service directly in the store for test setup.
func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service {
t.Helper()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: name,
Mode: protocol,
Domain: domain,
ProxyCluster: cluster,
ListenPort: port,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := s.CreateService(context.Background(), svc)
require.NoError(t, err)
return svc
}
func TestPortConflict_TCPSamePortCluster(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "conflicting-tcp",
Mode: "tcp",
Domain: "conflicting-tcp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TCP+TCP on same port/cluster should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestPortConflict_UDPSamePortCluster(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "conflicting-udp",
Mode: "udp",
Domain: "conflicting-udp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "UDP+UDP on same port/cluster should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "new-tls",
Mode: "tls",
Domain: "app2.example.com",
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)")
}
func TestPortConflict_TLSSamePortSameDomain(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "duplicate-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TLS+TLS on same domain should be rejected")
assert.Contains(t, err.Error(), "domain already taken")
}
func TestPortConflict_TLSAndTCPSamePort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "new-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)")
}
func TestAutoAssign_TCPNoListenPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "auto-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set")
}
func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it")
assert.Contains(t, err.Error(), "custom ports")
}
func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability")
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS")
}
func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "ephemeral-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "ephemeral",
SourcePeer: testPeerID,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
require.NoError(t, err)
assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden")
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range", svc.ListenPort)
assert.True(t, svc.PortAutoAssigned)
}
func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "ephemeral-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Source: "ephemeral",
SourcePeer: testPeerID,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
require.NoError(t, err)
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
assert.False(t, svc.PortAutoAssigned)
}
func TestAutoAssign_AvoidsExistingPorts(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existingPort := uint16(20000)
seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "auto-tcp",
Mode: "tcp",
Domain: "auto-tcp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing")
assert.True(t, svc.PortAutoAssigned)
}
func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported")
assert.False(t, svc.PortAutoAssigned)
}
func TestUpdate_PreservesExistingListenPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0")
}
func TestUpdate_AllowsPortChange(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
})
require.NoError(t, err)
assert.NotEmpty(t, resp.ServiceName)
assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain")
assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports")
assert.Contains(t, resp.ServiceURL, "tcp://")
}
func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
ListenPort: 15432,
})
require.NoError(t, err)
assert.False(t, resp.PortAutoAssigned)
assert.Contains(t, resp.ServiceURL, ":15432")
}
func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
})
require.NoError(t, err)
// When no explicit listen port, defaults to target port
assert.Contains(t, resp.ServiceURL, ":5432")
assert.False(t, resp.PortAutoAssigned)
}
func TestCreateServiceFromPeer_TLS(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 443,
Mode: "tls",
})
require.NoError(t, err)
assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain")
assert.Contains(t, resp.ServiceURL, "tls://")
assert.Contains(t, resp.ServiceURL, ":443")
// TLS always keeps its port (not port-based protocol for auto-assign)
assert.False(t, resp.PortAutoAssigned)
}
func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "tcp",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Renew after stop should fail
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.Error(t, err)
}
func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "tcp",
Pin: "123456",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication is not supported")
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"math/rand/v2"
"os"
"slices"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -23,6 +25,45 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const (
defaultAutoAssignPortMin uint16 = 10000
defaultAutoAssignPortMax uint16 = 49151
// EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN"
// EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX"
)
var (
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
)
func init() {
autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin)
autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax)
if autoAssignPortMin > autoAssignPortMax {
log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults",
EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax)
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
}
}
func portFromEnv(key string, fallback uint16) uint16 {
val := os.Getenv(key)
if val == "" {
return fallback
}
n, err := strconv.ParseUint(val, 10, 16)
if err != nil {
log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err)
return fallback
}
return uint16(n)
}
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
@@ -115,6 +156,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
@@ -197,55 +239,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
@@ -261,11 +267,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
})
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
svc.ListenPort = 0
}
if svc.ListenPort == 0 {
port, err := m.assignPort(ctx, tx, svc.ProxyCluster)
if err != nil {
return err
}
svc.ListenPort = port
svc.PortAutoAssigned = true
}
return nil
}
// checkPortConflict rejects L4 services that would conflict on the same listener.
// For TCP/UDP: unique per cluster+protocol+port.
// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports).
// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked:
// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener.
func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 {
return nil
}
existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort)
if err != nil {
return fmt.Errorf("query port conflicts: %w", err)
}
for _, s := range existing {
if s.ID == svc.ID {
continue
}
// TLS services on the same port are allowed if they have different domains (SNI routing)
if svc.Mode == service.ModeTLS && s.Domain != svc.Domain {
continue
}
return status.Errorf(status.AlreadyExists,
"%s port %d is already in use by service %q on cluster %s",
svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster)
}
return nil
}
// assignPort picks a random available port on the cluster within the auto-assign range.
func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) {
services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster)
if err != nil {
return 0, fmt.Errorf("query cluster ports: %w", err)
}
occupied := make(map[uint16]struct{}, len(services))
for _, s := range services {
if s.ListenPort > 0 {
occupied[s.ListenPort] = struct{}{}
}
}
portRange := int(autoAssignPortMax-autoAssignPortMin) + 1
for range 100 {
port := autoAssignPortMin + uint16(rand.IntN(portRange))
if _, taken := occupied[port]; !taken {
return port, nil
}
}
for port := autoAssignPortMin; port <= autoAssignPortMax; port++ {
if _, taken := occupied[port]; !taken {
return port, nil
}
}
return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster)
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error {
// Lock the peer row to serialize concurrent creates for the same peer.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
return nil
}
// checkDomainAvailable checks that no other service already uses this domain.
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
return fmt.Errorf("check existing service: %w", err)
}
return nil
}
@@ -322,6 +472,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
@@ -335,12 +489,18 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -351,23 +511,39 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
service.ProxyCluster = newCluster
svc.ProxyCluster = newCluster
}
}
return nil
}
// validateProtocolChange rejects mode changes on update.
// Only empty<->HTTP is allowed; all other transitions are rejected.
func validateProtocolChange(oldMode, newMode string) error {
if newMode == "" || newMode == oldMode {
return nil
}
if isHTTPFamily(oldMode) && isHTTPFamily(newMode) {
return nil
}
return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode)
}
func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http"
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
@@ -388,6 +564,13 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) preserveListenPort(svc, existing *service.Service) {
if existing.ListenPort > 0 && svc.ListenPort == 0 {
svc.ListenPort = existing.ListenPort
svc.PortAutoAssigned = existing.PortAutoAssigned
}
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
@@ -675,6 +858,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) {
return m.buildRandomDomain(serviceName)
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
@@ -696,9 +883,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
domain, err := m.resolveDefaultDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
return nil, err
}
svc.Domain = domain
}
@@ -739,10 +926,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
serviceURL := "https://" + svc.Domain
if service.IsL4Protocol(svc.Mode) {
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
}
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
ServiceName: svc.Name,
ServiceURL: serviceURL,
Domain: svc.Domain,
PortAutoAssigned: svc.PortAutoAssigned,
}, nil
}
@@ -761,64 +954,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
func (m *Manager) getDefaultClusterDomain() (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
return "", fmt.Errorf("unable to get cluster domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
return "", fmt.Errorf("no cluster domains available")
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
return clusterDomains[rand.IntN(len(clusterDomains))], nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
domain, err := m.getDefaultClusterDomain()
if err != nil {
return "", err
}
return name + "." + domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// deleteServiceFromPeer deletes a peer-initiated service identified by service ID.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error {
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode)
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {

View File

@@ -803,8 +803,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -826,9 +826,9 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
Port: 80,
Mode: "http",
Domain: "example.com",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -847,8 +847,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
require.NoError(t, err)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -860,8 +860,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
Port: 0,
Mode: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -878,62 +878,52 @@ func TestExposeServiceRequestValidate(t *testing.T) {
}{
{
name: "valid http request",
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
name: "https mode rejected",
req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"},
wantErr: "unsupported mode",
},
{
name: "port zero rejected",
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
name: "unsupported mode",
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"},
wantErr: "unsupported mode",
},
{
name: "invalid pin format",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
@@ -966,14 +956,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
// First create a service
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false)
require.NoError(t, err)
// Verify service is deleted
@@ -982,16 +972,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true)
require.NoError(t, err)
})
}
@@ -1003,13 +994,14 @@ func TestStopServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
@@ -1022,8 +1014,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -1042,8 +1034,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
Port: 9090,
Mode: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete")
}
@@ -1054,8 +1046,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -1076,21 +1068,22 @@ func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
})
}
@@ -1191,3 +1184,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string
oldP string
newP string
wantErr bool
}{
{"empty to http", "", "http", false},
{"http to http", "http", "http", false},
{"same protocol", "tcp", "tcp", false},
{"empty new proto", "tcp", "", false},
{"http to tcp", "http", "tcp", true},
{"tcp to udp", "tcp", "udp", true},
{"tls to http", "tls", "http", true},
{"udp to tls", "udp", "tls", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateProtocolChange(tt.oldP, tt.newP)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot change mode")
} else {
require.NoError(t, err)
}
})
}
}

View File

@@ -34,6 +34,7 @@ const (
)
type Status string
type TargetType string
const (
StatusPending Status = "pending"
@@ -43,34 +44,36 @@ const (
StatusCertificateFailed Status = "certificate_failed"
StatusError Status = "error"
TargetTypePeer = "peer"
TargetTypeHost = "host"
TargetTypeDomain = "domain"
TargetTypeSubnet = "subnet"
TargetTypePeer TargetType = "peer"
TargetTypeHost TargetType = "host"
TargetTypeDomain TargetType = "domain"
TargetTypeSubnet TargetType = "subnet"
SourcePermanent = "permanent"
SourceEphemeral = "ephemeral"
)
type TargetOptions struct {
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
SkipTLSVerify bool `json:"skip_tls_verify"`
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
}
type Target struct {
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port int `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
ID uint `gorm:"primaryKey" json:"-"`
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
Path *string `json:"path,omitempty"`
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
Port uint16 `gorm:"index:idx_target_port" json:"port"`
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
TargetType TargetType `gorm:"index:idx_target_type" json:"target_type"`
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
Options TargetOptions `gorm:"embedded" json:"options"`
ProxyProtocol bool `json:"proxy_protocol"`
}
type PasswordAuthConfig struct {
@@ -146,23 +149,10 @@ type Service struct {
SessionPublicKey string `gorm:"column:session_public_key"`
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
SourcePeer string `gorm:"index:idx_service_source_peer"`
}
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
for _, target := range targets {
target.AccountID = accountID
}
s := &Service{
AccountID: accountID,
Name: name,
Domain: domain,
ProxyCluster: proxyCluster,
Targets: targets,
Enabled: enabled,
}
s.InitNewRecord()
return s
// Mode determines the service type: "http", "tcp", "udp", or "tls".
Mode string `gorm:"default:'http'"`
ListenPort uint16
PortAutoAssigned bool
}
// InitNewRecord generates a new unique ID and resets metadata for a newly created
@@ -177,21 +167,17 @@ func (s *Service) InitNewRecord() {
}
func (s *Service) ToAPIResponse() *api.Service {
s.Auth.ClearSecrets()
authConfig := api.ServiceAuthConfig{}
if s.Auth.PasswordAuth != nil {
authConfig.PasswordAuth = &api.PasswordAuthConfig{
Enabled: s.Auth.PasswordAuth.Enabled,
Password: s.Auth.PasswordAuth.Password,
Enabled: s.Auth.PasswordAuth.Enabled,
}
}
if s.Auth.PinAuth != nil {
authConfig.PinAuth = &api.PINAuthConfig{
Enabled: s.Auth.PinAuth.Enabled,
Pin: s.Auth.PinAuth.Pin,
}
}
@@ -208,13 +194,18 @@ func (s *Service) ToAPIResponse() *api.Service {
st := api.ServiceTarget{
Path: target.Path,
Host: &target.Host,
Port: target.Port,
Port: int(target.Port),
Protocol: api.ServiceTargetProtocol(target.Protocol),
TargetId: target.TargetId,
TargetType: api.ServiceTargetTargetType(target.TargetType),
Enabled: target.Enabled,
}
st.Options = targetOptionsToAPI(target.Options)
opts := targetOptionsToAPI(target.Options)
if opts == nil {
opts = &api.ServiceTargetOptions{}
}
opts.ProxyProtocol = &target.ProxyProtocol
st.Options = opts
apiTargets = append(apiTargets, st)
}
@@ -227,6 +218,9 @@ func (s *Service) ToAPIResponse() *api.Service {
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
}
mode := api.ServiceMode(s.Mode)
listenPort := int(s.ListenPort)
resp := &api.Service{
Id: s.ID,
Name: s.Name,
@@ -237,6 +231,9 @@ func (s *Service) ToAPIResponse() *api.Service {
RewriteRedirects: &s.RewriteRedirects,
Auth: authConfig,
Meta: meta,
Mode: &mode,
ListenPort: &listenPort,
PortAutoAssigned: &s.PortAutoAssigned,
}
if s.ProxyCluster != "" {
@@ -247,37 +244,7 @@ func (s *Service) ToAPIResponse() *api.Service {
}
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
// TODO: Make path prefix stripping configurable per-target.
// Currently the matching prefix is baked into the target URL path,
// so the proxy strips-then-re-adds it (effectively a no-op).
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/", // TODO: support service path
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pm := &proto.PathMapping{
Path: path,
Target: targetURL.String(),
}
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
}
pathMappings := s.buildPathMappings()
auth := &proto.Authentication{
SessionKey: s.SessionPublicKey,
@@ -306,9 +273,58 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
AccountId: s.AccountID,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Mode: s.Mode,
ListenPort: int32(s.ListenPort), //nolint:gosec
}
}
// buildPathMappings constructs PathMapping entries from targets.
// For HTTP/HTTPS, each target becomes a path-based route with a full URL.
// For L4/TLS, a single target maps to a host:port address.
func (s *Service) buildPathMappings() []*proto.PathMapping {
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
for _, target := range s.Targets {
if !target.Enabled {
continue
}
if IsL4Protocol(s.Mode) {
pm := &proto.PathMapping{
Target: net.JoinHostPort(target.Host, strconv.FormatUint(uint64(target.Port), 10)),
}
opts := l4TargetOptionsToProto(target)
if opts != nil {
pm.Options = opts
}
pathMappings = append(pathMappings, pm)
continue
}
// HTTP/HTTPS: build full URL
targetURL := url.URL{
Scheme: target.Protocol,
Host: target.Host,
Path: "/",
}
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
}
path := "/"
if target.Path != nil {
path = *target.Path
}
pm := &proto.PathMapping{
Path: path,
Target: targetURL.String(),
}
pm.Options = targetOptionsToProto(target.Options)
pathMappings = append(pathMappings, pm)
}
return pathMappings
}
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
switch op {
case Create:
@@ -325,8 +341,8 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
// isDefaultPort reports whether port is the standard default for the given scheme
// (443 for https, 80 for http).
func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
func isDefaultPort(scheme string, port uint16) bool {
return (scheme == TargetProtoHTTPS && port == 443) || (scheme == TargetProtoHTTP && port == 80)
}
// PathRewriteMode controls how the request path is rewritten before forwarding.
@@ -346,7 +362,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
}
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
return nil
}
apiOpts := &api.ServiceTargetOptions{}
@@ -357,6 +373,10 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
s := opts.RequestTimeout.String()
apiOpts.RequestTimeout = &s
}
if opts.SessionIdleTimeout != 0 {
s := opts.SessionIdleTimeout.String()
apiOpts.SessionIdleTimeout = &s
}
if opts.PathRewrite != "" {
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
apiOpts.PathRewrite = &pr
@@ -382,6 +402,23 @@ func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
return popts
}
// l4TargetOptionsToProto converts L4-relevant target options to proto.
func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions {
if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 {
return nil
}
opts := &proto.PathTargetOptions{
ProxyProtocol: target.ProxyProtocol,
}
if target.Options.RequestTimeout > 0 {
opts.RequestTimeout = durationpb.New(target.Options.RequestTimeout)
}
if target.Options.SessionIdleTimeout > 0 {
opts.SessionIdleTimeout = durationpb.New(target.Options.SessionIdleTimeout)
}
return opts
}
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
var opts TargetOptions
if o.SkipTlsVerify != nil {
@@ -394,6 +431,13 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
}
opts.RequestTimeout = d
}
if o.SessionIdleTimeout != nil {
d, err := time.ParseDuration(*o.SessionIdleTimeout)
if err != nil {
return opts, fmt.Errorf("target %d: parse session_idle_timeout %q: %w", idx, *o.SessionIdleTimeout, err)
}
opts.SessionIdleTimeout = d
}
if o.PathRewrite != nil {
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
}
@@ -408,15 +452,49 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
s.Domain = req.Domain
s.AccountID = accountID
targets := make([]*Target, 0, len(req.Targets))
for i, apiTarget := range req.Targets {
if req.Mode != nil {
s.Mode = string(*req.Mode)
}
if req.ListenPort != nil {
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
}
targets, err := targetsFromAPI(accountID, req.Targets)
if err != nil {
return err
}
s.Targets = targets
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth != nil {
s.Auth = authFromAPI(req.Auth)
}
return nil
}
func targetsFromAPI(accountID string, apiTargetsPtr *[]api.ServiceTarget) ([]*Target, error) {
var apiTargets []api.ServiceTarget
if apiTargetsPtr != nil {
apiTargets = *apiTargetsPtr
}
targets := make([]*Target, 0, len(apiTargets))
for i, apiTarget := range apiTargets {
target := &Target{
AccountID: accountID,
Path: apiTarget.Path,
Port: apiTarget.Port,
Port: uint16(apiTarget.Port), //nolint:gosec // validated by API layer
Protocol: string(apiTarget.Protocol),
TargetId: apiTarget.TargetId,
TargetType: string(apiTarget.TargetType),
TargetType: TargetType(apiTarget.TargetType),
Enabled: apiTarget.Enabled,
}
if apiTarget.Host != nil {
@@ -425,49 +503,42 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
if apiTarget.Options != nil {
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
if err != nil {
return err
return nil, err
}
target.Options = opts
if apiTarget.Options.ProxyProtocol != nil {
target.ProxyProtocol = *apiTarget.Options.ProxyProtocol
}
}
targets = append(targets, target)
}
s.Targets = targets
return targets, nil
}
s.Enabled = req.Enabled
if req.PassHostHeader != nil {
s.PassHostHeader = *req.PassHostHeader
}
if req.RewriteRedirects != nil {
s.RewriteRedirects = *req.RewriteRedirects
}
if req.Auth.PasswordAuth != nil {
s.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: req.Auth.PasswordAuth.Enabled,
Password: req.Auth.PasswordAuth.Password,
func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
var auth AuthConfig
if reqAuth.PasswordAuth != nil {
auth.PasswordAuth = &PasswordAuthConfig{
Enabled: reqAuth.PasswordAuth.Enabled,
Password: reqAuth.PasswordAuth.Password,
}
}
if req.Auth.PinAuth != nil {
s.Auth.PinAuth = &PINAuthConfig{
Enabled: req.Auth.PinAuth.Enabled,
Pin: req.Auth.PinAuth.Pin,
if reqAuth.PinAuth != nil {
auth.PinAuth = &PINAuthConfig{
Enabled: reqAuth.PinAuth.Enabled,
Pin: reqAuth.PinAuth.Pin,
}
}
if req.Auth.BearerAuth != nil {
if reqAuth.BearerAuth != nil {
bearerAuth := &BearerAuthConfig{
Enabled: req.Auth.BearerAuth.Enabled,
Enabled: reqAuth.BearerAuth.Enabled,
}
if req.Auth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
if reqAuth.BearerAuth.DistributionGroups != nil {
bearerAuth.DistributionGroups = *reqAuth.BearerAuth.DistributionGroups
}
s.Auth.BearerAuth = bearerAuth
auth.BearerAuth = bearerAuth
}
return nil
return auth
}
func (s *Service) Validate() error {
@@ -478,14 +549,69 @@ func (s *Service) Validate() error {
return errors.New("service name exceeds maximum length of 255 characters")
}
if s.Domain == "" {
return errors.New("service domain is required")
}
if len(s.Targets) == 0 {
return errors.New("at least one target is required")
}
if s.Mode == "" {
s.Mode = ModeHTTP
}
switch s.Mode {
case ModeHTTP:
return s.validateHTTPMode()
case ModeTCP, ModeUDP:
return s.validateTCPUDPMode()
case ModeTLS:
return s.validateTLSMode()
default:
return fmt.Errorf("unsupported mode %q", s.Mode)
}
}
func (s *Service) validateHTTPMode() error {
if s.Domain == "" {
return errors.New("service domain is required")
}
if s.ListenPort != 0 {
return errors.New("listen_port is not supported for HTTP services")
}
return s.validateHTTPTargets()
}
func (s *Service) validateTCPUDPMode() error {
if s.Domain == "" {
return errors.New("domain is required for TCP/UDP services (used for cluster derivation)")
}
if s.isAuthEnabled() {
return errors.New("auth is not supported for TCP/UDP services")
}
if len(s.Targets) != 1 {
return errors.New("TCP/UDP services must have exactly one target")
}
if s.Mode == ModeUDP && s.Targets[0].ProxyProtocol {
return errors.New("proxy_protocol is not supported for UDP services")
}
return s.validateL4Target(s.Targets[0])
}
func (s *Service) validateTLSMode() error {
if s.Domain == "" {
return errors.New("domain is required for TLS services (used for SNI matching)")
}
if s.isAuthEnabled() {
return errors.New("auth is not supported for TLS services")
}
if s.ListenPort == 0 {
return errors.New("listen_port is required for TLS services")
}
if len(s.Targets) != 1 {
return errors.New("TLS services must have exactly one target")
}
return s.validateL4Target(s.Targets[0])
}
func (s *Service) validateHTTPTargets() error {
for i, target := range s.Targets {
switch target.TargetType {
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
@@ -500,6 +626,9 @@ func (s *Service) Validate() error {
if target.TargetId == "" {
return fmt.Errorf("target %d has empty target_id", i)
}
if target.ProxyProtocol {
return fmt.Errorf("target %d: proxy_protocol is not supported for HTTP services", i)
}
if err := validateTargetOptions(i, &target.Options); err != nil {
return err
}
@@ -508,11 +637,62 @@ func (s *Service) Validate() error {
return nil
}
func (s *Service) validateL4Target(target *Target) error {
if target.Port == 0 {
return errors.New("target port is required for L4 services")
}
if target.TargetId == "" {
return errors.New("target_id is required for L4 services")
}
switch target.TargetType {
case TargetTypePeer, TargetTypeHost:
// OK
case TargetTypeSubnet:
if target.Host == "" {
return errors.New("target host is required for subnet targets")
}
default:
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
}
if target.Path != nil && *target.Path != "" && *target.Path != "/" {
return errors.New("path is not supported for L4 services")
}
return nil
}
// Service mode constants.
const (
maxRequestTimeout = 5 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
ModeHTTP = "http"
ModeTCP = "tcp"
ModeUDP = "udp"
ModeTLS = "tls"
)
// Target protocol constants (URL scheme for backend connections).
const (
TargetProtoHTTP = "http"
TargetProtoHTTPS = "https"
TargetProtoTCP = "tcp"
TargetProtoUDP = "udp"
)
// IsL4Protocol returns true if the mode requires port-based routing (TCP, UDP, or TLS).
func IsL4Protocol(mode string) bool {
return mode == ModeTCP || mode == ModeUDP || mode == ModeTLS
}
// IsPortBasedProtocol returns true if the mode relies on dedicated port allocation.
// TLS is excluded because it uses SNI routing and can share ports with other TLS services.
func IsPortBasedProtocol(mode string) bool {
return mode == ModeTCP || mode == ModeUDP
}
const (
maxRequestTimeout = 5 * time.Minute
maxSessionIdleTimeout = 10 * time.Minute
maxCustomHeaders = 16
maxHeaderKeyLen = 128
maxHeaderValueLen = 4096
)
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
@@ -560,6 +740,15 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
}
}
if opts.SessionIdleTimeout != 0 {
if opts.SessionIdleTimeout <= 0 {
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
}
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
}
}
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
return err
}
@@ -608,17 +797,49 @@ func containsCRLF(s string) bool {
}
func (s *Service) EventMeta() map[string]any {
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
meta := map[string]any{
"name": s.Name,
"domain": s.Domain,
"proxy_cluster": s.ProxyCluster,
"source": s.Source,
"auth": s.isAuthEnabled(),
"mode": s.Mode,
}
if s.ListenPort != 0 {
meta["listen_port"] = s.ListenPort
}
if len(s.Targets) > 0 {
t := s.Targets[0]
if t.ProxyProtocol {
meta["proxy_protocol"] = true
}
if t.Options.RequestTimeout != 0 {
meta["request_timeout"] = t.Options.RequestTimeout.String()
}
if t.Options.SessionIdleTimeout != 0 {
meta["session_idle_timeout"] = t.Options.SessionIdleTimeout.String()
}
}
return meta
}
func (s *Service) isAuthEnabled() bool {
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled)
}
func (s *Service) Copy() *Service {
targets := make([]*Target, len(s.Targets))
for i, target := range s.Targets {
targetCopy := *target
if target.Path != nil {
p := *target.Path
targetCopy.Path = &p
}
if len(target.Options.CustomHeaders) > 0 {
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
for k, v := range target.Options.CustomHeaders {
@@ -628,6 +849,24 @@ func (s *Service) Copy() *Service {
targets[i] = &targetCopy
}
authCopy := s.Auth
if s.Auth.PasswordAuth != nil {
pa := *s.Auth.PasswordAuth
authCopy.PasswordAuth = &pa
}
if s.Auth.PinAuth != nil {
pa := *s.Auth.PinAuth
authCopy.PinAuth = &pa
}
if s.Auth.BearerAuth != nil {
ba := *s.Auth.BearerAuth
if len(s.Auth.BearerAuth.DistributionGroups) > 0 {
ba.DistributionGroups = make([]string, len(s.Auth.BearerAuth.DistributionGroups))
copy(ba.DistributionGroups, s.Auth.BearerAuth.DistributionGroups)
}
authCopy.BearerAuth = &ba
}
return &Service{
ID: s.ID,
AccountID: s.AccountID,
@@ -638,12 +877,15 @@ func (s *Service) Copy() *Service {
Enabled: s.Enabled,
PassHostHeader: s.PassHostHeader,
RewriteRedirects: s.RewriteRedirects,
Auth: s.Auth,
Auth: authCopy,
Meta: s.Meta,
SessionPrivateKey: s.SessionPrivateKey,
SessionPublicKey: s.SessionPublicKey,
Source: s.Source,
SourcePeer: s.SourcePeer,
Mode: s.Mode,
ListenPort: s.ListenPort,
PortAutoAssigned: s.PortAutoAssigned,
}
}
@@ -688,12 +930,16 @@ var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
type ExposeServiceRequest struct {
NamePrefix string
Port int
Protocol string
Domain string
Pin string
Password string
UserGroups []string
Port uint16
Mode string
// TargetProtocol is the protocol used to connect to the peer backend.
// For HTTP mode: "http" (default) or "https". For L4 modes: "tcp" or "udp".
TargetProtocol string
Domain string
Pin string
Password string
UserGroups []string
ListenPort uint16
}
// Validate checks all fields of the expose request.
@@ -702,12 +948,20 @@ func (r *ExposeServiceRequest) Validate() error {
return errors.New("request cannot be nil")
}
if r.Port < 1 || r.Port > 65535 {
if r.Port == 0 {
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
}
if r.Protocol != "http" && r.Protocol != "https" {
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
switch r.Mode {
case ModeHTTP, ModeTCP, ModeUDP, ModeTLS:
default:
return fmt.Errorf("unsupported mode %q", r.Mode)
}
if IsL4Protocol(r.Mode) {
if r.Pin != "" || r.Password != "" || len(r.UserGroups) > 0 {
return fmt.Errorf("authentication is not supported for %s mode", r.Mode)
}
}
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
@@ -729,55 +983,79 @@ func (r *ExposeServiceRequest) Validate() error {
// ToService builds a Service from the expose request.
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
service := &Service{
svc := &Service{
AccountID: accountID,
Name: serviceName,
Mode: r.Mode,
Enabled: true,
Targets: []*Target{
{
AccountID: accountID,
Port: r.Port,
Protocol: r.Protocol,
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
}
// If domain is empty, CreateServiceFromPeer generates a unique subdomain.
// When explicitly provided, the service name is prepended as a subdomain.
if r.Domain != "" {
svc.Domain = serviceName + "." + r.Domain
}
if IsL4Protocol(r.Mode) {
svc.ListenPort = r.Port
if r.ListenPort > 0 {
svc.ListenPort = r.ListenPort
}
}
var targetProto string
switch {
case !IsL4Protocol(r.Mode):
targetProto = TargetProtoHTTP
if r.TargetProtocol != "" {
targetProto = r.TargetProtocol
}
case r.Mode == ModeUDP:
targetProto = TargetProtoUDP
default:
targetProto = TargetProtoTCP
}
svc.Targets = []*Target{
{
AccountID: accountID,
Port: r.Port,
Protocol: targetProto,
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
}
if r.Domain != "" {
service.Domain = serviceName + "." + r.Domain
}
if r.Pin != "" {
service.Auth.PinAuth = &PINAuthConfig{
svc.Auth.PinAuth = &PINAuthConfig{
Enabled: true,
Pin: r.Pin,
}
}
if r.Password != "" {
service.Auth.PasswordAuth = &PasswordAuthConfig{
svc.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: true,
Password: r.Password,
}
}
if len(r.UserGroups) > 0 {
service.Auth.BearerAuth = &BearerAuthConfig{
svc.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: r.UserGroups,
}
}
return service
return svc
}
// ExposeServiceResponse contains the result of a successful peer expose creation.
type ExposeServiceResponse struct {
ServiceName string
ServiceURL string
Domain string
ServiceName string
ServiceURL string
Domain string
PortAutoAssigned bool
}
// GenerateExposeName generates a random service name for peer-exposed services.

View File

@@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) {
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
assert.ErrorContains(t, rp.Validate(), "at least one target is required")
}
func TestValidate_EmptyTargetId(t *testing.T) {
@@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
port uint16
want bool
}{
{"http", 80, true},
@@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
name string
protocol string
host string
port int
port uint16
wantTarget string
}{
{
@@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) {
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
service := req.ToService("account-1", "peer-1", "mysvc")
@@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
require.Len(t, service.Targets, 1)
target := service.Targets[0]
assert.Equal(t, 8080, target.Port)
assert.Equal(t, uint16(8080), target.Port)
assert.Equal(t, "http", target.Protocol)
assert.Equal(t, "peer-1", target.TargetId)
assert.Equal(t, TargetTypePeer, target.TargetType)
@@ -730,3 +730,182 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
require.NotNil(t, service.Auth.BearerAuth)
})
}
func TestValidate_TLSOnly(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_TLSMissingListenPort(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 0,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "listen_port is required")
}
func TestValidate_TLSMissingDomain(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_TCPValid(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_TCPMissingListenPort(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)")
}
func TestValidate_L4MultipleTargets(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
{TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "exactly one target")
}
func TestValidate_L4TargetMissingPort(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "port is required")
}
func TestValidate_TLSInvalidTargetType(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.Error(t, rp.Validate())
}
func TestValidate_TLSSubnetValid(t *testing.T) {
rp := &Service{
Name: "tls-subnet",
Mode: "tls",
Domain: "example.com",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_HTTPProxyProtocolRejected(t *testing.T) {
rp := validProxy()
rp.Targets[0].ProxyProtocol = true
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP")
}
func TestValidate_UDPProxyProtocolRejected(t *testing.T) {
rp := &Service{
Name: "udp-svc",
Mode: "udp",
Domain: "cluster.test",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true},
},
}
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP")
}
func TestValidate_TCPProxyProtocolAllowed(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true},
},
}
require.NoError(t, rp.Validate())
}
func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) {
tests := []struct {
name string
req ExposeServiceRequest
}{
{
name: "tcp with pin",
req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"},
},
{
name: "udp with password",
req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"},
},
{
name: "tls with user groups",
req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication is not supported")
})
}
}
func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
require.NoError(t, req.Validate())
}