mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)
This commit is contained in:
@@ -22,7 +22,7 @@ type Manager interface {
|
||||
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
|
||||
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
|
||||
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
|
||||
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
|
||||
StartExposeReaper(ctx context.Context)
|
||||
}
|
||||
|
||||
@@ -211,17 +211,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer mocks base method.
|
||||
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
|
||||
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID)
|
||||
}
|
||||
|
||||
// SetCertificateIssuedAt mocks base method.
|
||||
@@ -265,17 +265,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// StopServiceFromPeer mocks base method.
|
||||
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
|
||||
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
|
||||
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID)
|
||||
}
|
||||
|
||||
// UpdateService mocks base method.
|
||||
|
||||
@@ -11,19 +11,22 @@ import (
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
manager rpservice.Manager
|
||||
manager rpservice.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all service HTTP endpoints.
|
||||
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
manager: manager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
|
||||
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
|
||||
|
||||
@@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) {
|
||||
|
||||
// Create a non-expired service
|
||||
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8081,
|
||||
Protocol: "http",
|
||||
Port: 8081,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Delete the service before reaping
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reaping should handle the already-deleted service gracefully
|
||||
@@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) {
|
||||
|
||||
for i := range 5 {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080 + i,
|
||||
Protocol: "http",
|
||||
Port: uint16(8080 + i),
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) {
|
||||
|
||||
t.Run("renew succeeds for active service", func(t *testing.T) {
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8082,
|
||||
Protocol: "http",
|
||||
Port: 8082,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, lookupErr)
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no active expose session")
|
||||
})
|
||||
@@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) {
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8083,
|
||||
Protocol: "http",
|
||||
Port: 8083,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) {
|
||||
|
||||
for i := range maxExposesPerPeer {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8090 + i,
|
||||
Protocol: "http",
|
||||
Port: uint16(8090 + i),
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err, "expose %d should succeed", i)
|
||||
}
|
||||
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9999,
|
||||
Protocol: "http",
|
||||
Port: 9999,
|
||||
Mode: "http",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
|
||||
@@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8086,
|
||||
Protocol: "http",
|
||||
Port: 8086,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) {
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Renew it before the reaper runs
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||
@@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
|
||||
require.NoError(t, err, "renewed service should survive reaping")
|
||||
}
|
||||
|
||||
// resolveServiceIDByDomain looks up a service ID by domain in tests.
|
||||
func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string {
|
||||
t.Helper()
|
||||
svc, err := s.GetServiceByDomain(context.Background(), domain)
|
||||
require.NoError(t, err)
|
||||
return svc.ID
|
||||
}
|
||||
|
||||
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -0,0 +1,582 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const testCluster = "test-cluster"
|
||||
|
||||
func boolPtr(v bool) *bool { return &v }
|
||||
|
||||
// setupL4Test creates a manager with a mock proxy controller for L4 port tests.
|
||||
func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
ctx := context.Background()
|
||||
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
err = testStore.SaveAccount(ctx, &types.Account{
|
||||
Id: testAccountID,
|
||||
CreatedBy: testUserID,
|
||||
Settings: &types.Settings{
|
||||
PeerExposeEnabled: true,
|
||||
PeerExposeGroups: []string{testGroupID},
|
||||
},
|
||||
Users: map[string]*types.User{
|
||||
testUserID: {
|
||||
Id: testUserID,
|
||||
AccountID: testAccountID,
|
||||
Role: types.UserRoleAdmin,
|
||||
},
|
||||
},
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
testPeerID: {
|
||||
ID: testPeerID,
|
||||
AccountID: testAccountID,
|
||||
Key: "test-key",
|
||||
DNSLabel: "test-peer",
|
||||
Name: "test-peer",
|
||||
IP: net.ParseIP("100.64.0.1"),
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*types.Group{
|
||||
testGroupID: {
|
||||
ID: testGroupID,
|
||||
AccountID: testAccountID,
|
||||
Name: "Expose Group",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockCtrl := proxy.NewMockController(ctrl)
|
||||
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
|
||||
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
|
||||
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
|
||||
},
|
||||
}
|
||||
|
||||
mgr := &Manager{
|
||||
store: testStore,
|
||||
accountManager: accountMgr,
|
||||
permissionsManager: permissions.NewManager(testStore),
|
||||
proxyController: mockCtrl,
|
||||
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
|
||||
}
|
||||
mgr.exposeReaper = &exposeReaper{manager: mgr}
|
||||
|
||||
return mgr, testStore, mockCtrl
|
||||
}
|
||||
|
||||
// seedService creates a service directly in the store for test setup.
|
||||
func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service {
|
||||
t.Helper()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: name,
|
||||
Mode: protocol,
|
||||
Domain: domain,
|
||||
ProxyCluster: cluster,
|
||||
ListenPort: port,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
err := s.CreateService(context.Background(), svc)
|
||||
require.NoError(t, err)
|
||||
return svc
|
||||
}
|
||||
|
||||
func TestPortConflict_TCPSamePortCluster(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432)
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "conflicting-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: "conflicting-tcp." + testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5432,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.Error(t, err, "TCP+TCP on same port/cluster should be rejected")
|
||||
assert.Contains(t, err.Error(), "already in use")
|
||||
}
|
||||
|
||||
func TestPortConflict_UDPSamePortCluster(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432)
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "conflicting-udp",
|
||||
Mode: "udp",
|
||||
Domain: "conflicting-udp." + testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5432,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.Error(t, err, "UDP+UDP on same port/cluster should be rejected")
|
||||
assert.Contains(t, err.Error(), "already in use")
|
||||
}
|
||||
|
||||
func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443)
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "new-tls",
|
||||
Mode: "tls",
|
||||
Domain: "app2.example.com",
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 443,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)")
|
||||
}
|
||||
|
||||
func TestPortConflict_TLSSamePortSameDomain(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "duplicate-tls",
|
||||
Mode: "tls",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 443,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.Error(t, err, "TLS+TLS on same domain should be rejected")
|
||||
assert.Contains(t, err.Error(), "domain already taken")
|
||||
}
|
||||
|
||||
func TestPortConflict_TLSAndTCPSamePort(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "new-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 443,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)")
|
||||
}
|
||||
|
||||
func TestAutoAssign_TCPNoListenPort(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "auto-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
|
||||
"auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax)
|
||||
assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set")
|
||||
}
|
||||
|
||||
func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "custom-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5555,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it")
|
||||
assert.Contains(t, err.Error(), "custom ports")
|
||||
}
|
||||
|
||||
func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "custom-tls",
|
||||
Mode: "tls",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 9999,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability")
|
||||
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
|
||||
assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS")
|
||||
}
|
||||
|
||||
func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "ephemeral-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5555,
|
||||
Enabled: true,
|
||||
Source: "ephemeral",
|
||||
SourcePeer: testPeerID,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden")
|
||||
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
|
||||
"auto-assigned port %d should be in range", svc.ListenPort)
|
||||
assert.True(t, svc.PortAutoAssigned)
|
||||
}
|
||||
|
||||
func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "ephemeral-tls",
|
||||
Mode: "tls",
|
||||
Domain: "app.example.com",
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 9999,
|
||||
Enabled: true,
|
||||
Source: "ephemeral",
|
||||
SourcePeer: testPeerID,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
|
||||
assert.False(t, svc.PortAutoAssigned)
|
||||
}
|
||||
|
||||
func TestAutoAssign_AvoidsExistingPorts(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
existingPort := uint16(20000)
|
||||
seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort)
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "auto-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: "auto-tcp." + testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing")
|
||||
assert.True(t, svc.PortAutoAssigned)
|
||||
}
|
||||
|
||||
func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
svc := &rpservice.Service{
|
||||
AccountID: testAccountID,
|
||||
Name: "custom-tcp",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 5555,
|
||||
Enabled: true,
|
||||
Source: "permanent",
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
|
||||
},
|
||||
}
|
||||
svc.InitNewRecord()
|
||||
|
||||
err := mgr.persistNewService(ctx, testAccountID, svc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported")
|
||||
assert.False(t, svc.PortAutoAssigned)
|
||||
}
|
||||
|
||||
func TestUpdate_PreservesExistingListenPort(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc-renamed",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 0,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0")
|
||||
}
|
||||
|
||||
func TestUpdate_AllowsPortChange(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
|
||||
|
||||
updated := &rpservice.Service{
|
||||
ID: existing.ID,
|
||||
AccountID: testAccountID,
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: testCluster,
|
||||
ProxyCluster: testCluster,
|
||||
ListenPort: 54321,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TCP(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 5432,
|
||||
Mode: "tcp",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEmpty(t, resp.ServiceName)
|
||||
assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain")
|
||||
assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports")
|
||||
assert.Contains(t, resp.ServiceURL, "tcp://")
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 5432,
|
||||
Mode: "tcp",
|
||||
ListenPort: 15432,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, resp.PortAutoAssigned)
|
||||
assert.Contains(t, resp.ServiceURL, ":15432")
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 5432,
|
||||
Mode: "tcp",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// When no explicit listen port, defaults to target port
|
||||
assert.Contains(t, resp.ServiceURL, ":5432")
|
||||
assert.False(t, resp.PortAutoAssigned)
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TLS(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(false))
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 443,
|
||||
Mode: "tls",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain")
|
||||
assert.Contains(t, resp.ServiceURL, "tls://")
|
||||
assert.Contains(t, resp.ServiceURL, ":443")
|
||||
// TLS always keeps its port (not port-based protocol for auto-assign)
|
||||
assert.False(t, resp.PortAutoAssigned)
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) {
|
||||
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Mode: "tcp",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Renew after stop should fail
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) {
|
||||
mgr, _, _ := setupL4Test(t, boolPtr(true))
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Mode: "tcp",
|
||||
Pin: "123456",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication is not supported")
|
||||
}
|
||||
@@ -4,7 +4,9 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -23,6 +25,45 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAutoAssignPortMin uint16 = 10000
|
||||
defaultAutoAssignPortMax uint16 = 49151
|
||||
|
||||
// EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports.
|
||||
EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN"
|
||||
// EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports.
|
||||
EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX"
|
||||
)
|
||||
|
||||
var (
|
||||
autoAssignPortMin = defaultAutoAssignPortMin
|
||||
autoAssignPortMax = defaultAutoAssignPortMax
|
||||
)
|
||||
|
||||
func init() {
|
||||
autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin)
|
||||
autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax)
|
||||
if autoAssignPortMin > autoAssignPortMax {
|
||||
log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults",
|
||||
EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax)
|
||||
autoAssignPortMin = defaultAutoAssignPortMin
|
||||
autoAssignPortMax = defaultAutoAssignPortMax
|
||||
}
|
||||
}
|
||||
|
||||
func portFromEnv(key string, fallback uint16) uint16 {
|
||||
val := os.Getenv(key)
|
||||
if val == "" {
|
||||
return fallback
|
||||
}
|
||||
n, err := strconv.ParseUint(val, 10, 16)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err)
|
||||
return fallback
|
||||
}
|
||||
return uint16(n)
|
||||
}
|
||||
|
||||
const unknownHostPlaceholder = "unknown"
|
||||
|
||||
// ClusterDeriver derives the proxy cluster from a domain.
|
||||
@@ -115,6 +156,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
|
||||
return fmt.Errorf("unknown target type: %s", target.TargetType)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -197,55 +239,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
|
||||
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
|
||||
if svc.Domain != "" {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("failed to create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// persistNewEphemeralService creates an ephemeral service inside a single transaction
|
||||
// that also enforces the duplicate and per-peer limit checks atomically.
|
||||
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
|
||||
// for the same peer, preventing the per-peer limit from being bypassed.
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
// Lock the peer row to serialize concurrent creates for the same peer.
|
||||
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
|
||||
// table returns no rows and acquires no locks, allowing concurrent inserts
|
||||
// to bypass the per-peer limit.
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
|
||||
return fmt.Errorf("lock peer row: %w", err)
|
||||
}
|
||||
|
||||
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing expose: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
|
||||
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count peer exposes: %w", err)
|
||||
}
|
||||
if count >= int64(maxExposesPerPeer) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -261,11 +267,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
|
||||
})
|
||||
}
|
||||
|
||||
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
|
||||
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
|
||||
if !service.IsL4Protocol(svc.Mode) {
|
||||
return nil
|
||||
}
|
||||
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
|
||||
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
|
||||
}
|
||||
svc.ListenPort = 0
|
||||
}
|
||||
if svc.ListenPort == 0 {
|
||||
port, err := m.assignPort(ctx, tx, svc.ProxyCluster)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
svc.ListenPort = port
|
||||
svc.PortAutoAssigned = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPortConflict rejects L4 services that would conflict on the same listener.
|
||||
// For TCP/UDP: unique per cluster+protocol+port.
|
||||
// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports).
|
||||
// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked:
|
||||
// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener.
|
||||
func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error {
|
||||
if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort)
|
||||
if err != nil {
|
||||
return fmt.Errorf("query port conflicts: %w", err)
|
||||
}
|
||||
for _, s := range existing {
|
||||
if s.ID == svc.ID {
|
||||
continue
|
||||
}
|
||||
// TLS services on the same port are allowed if they have different domains (SNI routing)
|
||||
if svc.Mode == service.ModeTLS && s.Domain != svc.Domain {
|
||||
continue
|
||||
}
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"%s port %d is already in use by service %q on cluster %s",
|
||||
svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// assignPort picks a random available port on the cluster within the auto-assign range.
|
||||
func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) {
|
||||
services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("query cluster ports: %w", err)
|
||||
}
|
||||
|
||||
occupied := make(map[uint16]struct{}, len(services))
|
||||
for _, s := range services {
|
||||
if s.ListenPort > 0 {
|
||||
occupied[s.ListenPort] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
portRange := int(autoAssignPortMax-autoAssignPortMin) + 1
|
||||
for range 100 {
|
||||
port := autoAssignPortMin + uint16(rand.IntN(portRange))
|
||||
if _, taken := occupied[port]; !taken {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
|
||||
for port := autoAssignPortMin; port <= autoAssignPortMax; port++ {
|
||||
if _, taken := occupied[port]; !taken {
|
||||
return port, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster)
|
||||
}
|
||||
|
||||
// persistNewEphemeralService creates an ephemeral service inside a single transaction
|
||||
// that also enforces the duplicate and per-peer limit checks atomically.
|
||||
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
|
||||
// for the same peer, preventing the per-peer limit from being bypassed.
|
||||
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
|
||||
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.CreateService(ctx, svc); err != nil {
|
||||
return fmt.Errorf("create service: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error {
|
||||
// Lock the peer row to serialize concurrent creates for the same peer.
|
||||
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
|
||||
return fmt.Errorf("lock peer row: %w", err)
|
||||
}
|
||||
|
||||
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing expose: %w", err)
|
||||
}
|
||||
if exists {
|
||||
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
|
||||
}
|
||||
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("count peer exposes: %w", err)
|
||||
}
|
||||
if count >= int64(maxExposesPerPeer) {
|
||||
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkDomainAvailable checks that no other service already uses this domain.
|
||||
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
|
||||
existingService, err := transaction.GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
|
||||
return fmt.Errorf("failed to check existing service: %w", err)
|
||||
return fmt.Errorf("check existing service: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -322,6 +472,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateInfo.oldCluster = existingService.ProxyCluster
|
||||
updateInfo.domainChanged = existingService.Domain != service.Domain
|
||||
|
||||
@@ -335,12 +489,18 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
|
||||
m.preserveExistingAuthSecrets(service, existingService)
|
||||
m.preserveServiceMetadata(service, existingService)
|
||||
m.preserveListenPort(service, existingService)
|
||||
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
|
||||
|
||||
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.UpdateService(ctx, service); err != nil {
|
||||
return fmt.Errorf("update service: %w", err)
|
||||
}
|
||||
@@ -351,23 +511,39 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
|
||||
return &updateInfo, err
|
||||
}
|
||||
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
|
||||
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
|
||||
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.clusterDeriver != nil {
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
|
||||
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
|
||||
if err != nil {
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
|
||||
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
|
||||
} else {
|
||||
service.ProxyCluster = newCluster
|
||||
svc.ProxyCluster = newCluster
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateProtocolChange rejects mode changes on update.
|
||||
// Only empty<->HTTP is allowed; all other transitions are rejected.
|
||||
func validateProtocolChange(oldMode, newMode string) error {
|
||||
if newMode == "" || newMode == oldMode {
|
||||
return nil
|
||||
}
|
||||
if isHTTPFamily(oldMode) && isHTTPFamily(newMode) {
|
||||
return nil
|
||||
}
|
||||
return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode)
|
||||
}
|
||||
|
||||
func isHTTPFamily(mode string) bool {
|
||||
return mode == "" || mode == "http"
|
||||
}
|
||||
|
||||
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
|
||||
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
|
||||
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
|
||||
@@ -388,6 +564,13 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv
|
||||
service.SessionPublicKey = existingService.SessionPublicKey
|
||||
}
|
||||
|
||||
func (m *Manager) preserveListenPort(svc, existing *service.Service) {
|
||||
if existing.ListenPort > 0 && svc.ListenPort == 0 {
|
||||
svc.ListenPort = existing.ListenPort
|
||||
svc.PortAutoAssigned = existing.PortAutoAssigned
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyController.GetOIDCValidationConfig()
|
||||
|
||||
@@ -675,6 +858,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI
|
||||
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
|
||||
}
|
||||
|
||||
func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) {
|
||||
return m.buildRandomDomain(serviceName)
|
||||
}
|
||||
|
||||
// CreateServiceFromPeer creates a service initiated by a peer expose request.
|
||||
// It validates the request, checks expose permissions, enforces the per-peer limit,
|
||||
// creates the service, and tracks it for TTL-based reaping.
|
||||
@@ -696,9 +883,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
||||
svc.Source = service.SourceEphemeral
|
||||
|
||||
if svc.Domain == "" {
|
||||
domain, err := m.buildRandomDomain(svc.Name)
|
||||
domain, err := m.resolveDefaultDomain(svc.Name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
|
||||
return nil, err
|
||||
}
|
||||
svc.Domain = domain
|
||||
}
|
||||
@@ -739,10 +926,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
serviceURL := "https://" + svc.Domain
|
||||
if service.IsL4Protocol(svc.Mode) {
|
||||
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
|
||||
}
|
||||
|
||||
return &service.ExposeServiceResponse{
|
||||
ServiceName: svc.Name,
|
||||
ServiceURL: "https://" + svc.Domain,
|
||||
Domain: svc.Domain,
|
||||
ServiceName: svc.Name,
|
||||
ServiceURL: serviceURL,
|
||||
Domain: svc.Domain,
|
||||
PortAutoAssigned: svc.PortAutoAssigned,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -761,64 +954,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
|
||||
return groupIDs, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||
func (m *Manager) getDefaultClusterDomain() (string, error) {
|
||||
if m.clusterDeriver == nil {
|
||||
return "", fmt.Errorf("unable to get random domain")
|
||||
return "", fmt.Errorf("unable to get cluster domain")
|
||||
}
|
||||
clusterDomains := m.clusterDeriver.GetClusterDomains()
|
||||
if len(clusterDomains) == 0 {
|
||||
return "", fmt.Errorf("no cluster domains found for service %s", name)
|
||||
return "", fmt.Errorf("no cluster domains available")
|
||||
}
|
||||
index := rand.IntN(len(clusterDomains))
|
||||
domain := name + "." + clusterDomains[index]
|
||||
return domain, nil
|
||||
return clusterDomains[rand.IntN(len(clusterDomains))], nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildRandomDomain(name string) (string, error) {
|
||||
domain, err := m.getDefaultClusterDomain()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name + "." + domain, nil
|
||||
}
|
||||
|
||||
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
|
||||
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
|
||||
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
|
||||
return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID)
|
||||
}
|
||||
|
||||
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
|
||||
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
|
||||
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
|
||||
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
|
||||
// deleteServiceFromPeer deletes a peer-initiated service identified by service ID.
|
||||
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
|
||||
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
|
||||
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error {
|
||||
activityCode := activity.PeerServiceUnexposed
|
||||
if expired {
|
||||
activityCode = activity.PeerServiceExposeExpired
|
||||
}
|
||||
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
|
||||
}
|
||||
|
||||
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
|
||||
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
|
||||
svc, err := m.store.GetServiceByDomain(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if svc.Source != service.SourceEphemeral {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
|
||||
}
|
||||
|
||||
if svc.SourcePeer != peerID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
|
||||
}
|
||||
|
||||
return svc, nil
|
||||
return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode)
|
||||
}
|
||||
|
||||
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {
|
||||
|
||||
@@ -803,8 +803,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
}
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
@@ -826,9 +826,9 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 80,
|
||||
Protocol: "http",
|
||||
Domain: "example.com",
|
||||
Port: 80,
|
||||
Mode: "http",
|
||||
Domain: "example.com",
|
||||
}
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
@@ -847,8 +847,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
}
|
||||
|
||||
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
@@ -860,8 +860,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 0,
|
||||
Protocol: "http",
|
||||
Port: 0,
|
||||
Mode: "http",
|
||||
}
|
||||
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
@@ -878,62 +878,52 @@ func TestExposeServiceRequestValidate(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "valid http request",
|
||||
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "valid https request with pin",
|
||||
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
|
||||
wantErr: "",
|
||||
name: "https mode rejected",
|
||||
req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"},
|
||||
wantErr: "unsupported mode",
|
||||
},
|
||||
{
|
||||
name: "port zero rejected",
|
||||
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"},
|
||||
wantErr: "port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "negative port rejected",
|
||||
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
|
||||
wantErr: "port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "port above 65535 rejected",
|
||||
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
|
||||
wantErr: "port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "unsupported protocol",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
|
||||
wantErr: "unsupported protocol",
|
||||
name: "unsupported mode",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"},
|
||||
wantErr: "unsupported mode",
|
||||
},
|
||||
{
|
||||
name: "invalid pin format",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"},
|
||||
wantErr: "invalid pin",
|
||||
},
|
||||
{
|
||||
name: "pin too short",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"},
|
||||
wantErr: "invalid pin",
|
||||
},
|
||||
{
|
||||
name: "valid 6-digit pin",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"},
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "empty user group name",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}},
|
||||
wantErr: "user group name cannot be empty",
|
||||
},
|
||||
{
|
||||
name: "invalid name prefix",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"},
|
||||
wantErr: "invalid name prefix",
|
||||
},
|
||||
{
|
||||
name: "valid name prefix",
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
|
||||
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"},
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
@@ -966,14 +956,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
||||
|
||||
// First create a service
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
}
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete by domain using unexported method
|
||||
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify service is deleted
|
||||
@@ -982,16 +972,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("expire uses correct activity", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
}
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1003,13 +994,14 @@ func TestStopServiceFromPeer(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
req := &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
}
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
|
||||
@@ -1022,8 +1014,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -1042,8 +1034,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
|
||||
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
|
||||
|
||||
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9090,
|
||||
Protocol: "http",
|
||||
Port: 9090,
|
||||
Mode: "http",
|
||||
})
|
||||
assert.NoError(t, err, "new expose should succeed after API delete")
|
||||
}
|
||||
@@ -1054,8 +1046,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
|
||||
|
||||
for i := range 3 {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080 + i,
|
||||
Protocol: "http",
|
||||
Port: uint16(8080 + i),
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
@@ -1076,21 +1068,22 @@ func TestRenewServiceFromPeer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("renews tracked expose", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails for untracked domain", func(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
@@ -1191,3 +1184,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
|
||||
}
|
||||
|
||||
func TestValidateProtocolChange(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
oldP string
|
||||
newP string
|
||||
wantErr bool
|
||||
}{
|
||||
{"empty to http", "", "http", false},
|
||||
{"http to http", "http", "http", false},
|
||||
{"same protocol", "tcp", "tcp", false},
|
||||
{"empty new proto", "tcp", "", false},
|
||||
{"http to tcp", "http", "tcp", true},
|
||||
{"tcp to udp", "tcp", "udp", true},
|
||||
{"tls to http", "tls", "http", true},
|
||||
{"udp to tls", "udp", "tls", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateProtocolChange(tt.oldP, tt.newP)
|
||||
if tt.wantErr {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot change mode")
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ const (
|
||||
)
|
||||
|
||||
type Status string
|
||||
type TargetType string
|
||||
|
||||
const (
|
||||
StatusPending Status = "pending"
|
||||
@@ -43,34 +44,36 @@ const (
|
||||
StatusCertificateFailed Status = "certificate_failed"
|
||||
StatusError Status = "error"
|
||||
|
||||
TargetTypePeer = "peer"
|
||||
TargetTypeHost = "host"
|
||||
TargetTypeDomain = "domain"
|
||||
TargetTypeSubnet = "subnet"
|
||||
TargetTypePeer TargetType = "peer"
|
||||
TargetTypeHost TargetType = "host"
|
||||
TargetTypeDomain TargetType = "domain"
|
||||
TargetTypeSubnet TargetType = "subnet"
|
||||
|
||||
SourcePermanent = "permanent"
|
||||
SourceEphemeral = "ephemeral"
|
||||
)
|
||||
|
||||
type TargetOptions struct {
|
||||
SkipTLSVerify bool `json:"skip_tls_verify"`
|
||||
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
SkipTLSVerify bool `json:"skip_tls_verify"`
|
||||
RequestTimeout time.Duration `json:"request_timeout,omitempty"`
|
||||
SessionIdleTimeout time.Duration `json:"session_idle_timeout,omitempty"`
|
||||
PathRewrite PathRewriteMode `json:"path_rewrite,omitempty"`
|
||||
CustomHeaders map[string]string `gorm:"serializer:json" json:"custom_headers,omitempty"`
|
||||
}
|
||||
|
||||
type Target struct {
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port int `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType string `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
Options TargetOptions `gorm:"embedded" json:"options"`
|
||||
ID uint `gorm:"primaryKey" json:"-"`
|
||||
AccountID string `gorm:"index:idx_target_account;not null" json:"-"`
|
||||
ServiceID string `gorm:"index:idx_service_targets;not null" json:"-"`
|
||||
Path *string `json:"path,omitempty"`
|
||||
Host string `json:"host"` // the Host field is only used for subnet targets, otherwise ignored
|
||||
Port uint16 `gorm:"index:idx_target_port" json:"port"`
|
||||
Protocol string `gorm:"index:idx_target_protocol" json:"protocol"`
|
||||
TargetId string `gorm:"index:idx_target_id" json:"target_id"`
|
||||
TargetType TargetType `gorm:"index:idx_target_type" json:"target_type"`
|
||||
Enabled bool `gorm:"index:idx_target_enabled" json:"enabled"`
|
||||
Options TargetOptions `gorm:"embedded" json:"options"`
|
||||
ProxyProtocol bool `json:"proxy_protocol"`
|
||||
}
|
||||
|
||||
type PasswordAuthConfig struct {
|
||||
@@ -146,23 +149,10 @@ type Service struct {
|
||||
SessionPublicKey string `gorm:"column:session_public_key"`
|
||||
Source string `gorm:"default:'permanent';index:idx_service_source_peer"`
|
||||
SourcePeer string `gorm:"index:idx_service_source_peer"`
|
||||
}
|
||||
|
||||
func NewService(accountID, name, domain, proxyCluster string, targets []*Target, enabled bool) *Service {
|
||||
for _, target := range targets {
|
||||
target.AccountID = accountID
|
||||
}
|
||||
|
||||
s := &Service{
|
||||
AccountID: accountID,
|
||||
Name: name,
|
||||
Domain: domain,
|
||||
ProxyCluster: proxyCluster,
|
||||
Targets: targets,
|
||||
Enabled: enabled,
|
||||
}
|
||||
s.InitNewRecord()
|
||||
return s
|
||||
// Mode determines the service type: "http", "tcp", "udp", or "tls".
|
||||
Mode string `gorm:"default:'http'"`
|
||||
ListenPort uint16
|
||||
PortAutoAssigned bool
|
||||
}
|
||||
|
||||
// InitNewRecord generates a new unique ID and resets metadata for a newly created
|
||||
@@ -177,21 +167,17 @@ func (s *Service) InitNewRecord() {
|
||||
}
|
||||
|
||||
func (s *Service) ToAPIResponse() *api.Service {
|
||||
s.Auth.ClearSecrets()
|
||||
|
||||
authConfig := api.ServiceAuthConfig{}
|
||||
|
||||
if s.Auth.PasswordAuth != nil {
|
||||
authConfig.PasswordAuth = &api.PasswordAuthConfig{
|
||||
Enabled: s.Auth.PasswordAuth.Enabled,
|
||||
Password: s.Auth.PasswordAuth.Password,
|
||||
Enabled: s.Auth.PasswordAuth.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
if s.Auth.PinAuth != nil {
|
||||
authConfig.PinAuth = &api.PINAuthConfig{
|
||||
Enabled: s.Auth.PinAuth.Enabled,
|
||||
Pin: s.Auth.PinAuth.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,13 +194,18 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
st := api.ServiceTarget{
|
||||
Path: target.Path,
|
||||
Host: &target.Host,
|
||||
Port: target.Port,
|
||||
Port: int(target.Port),
|
||||
Protocol: api.ServiceTargetProtocol(target.Protocol),
|
||||
TargetId: target.TargetId,
|
||||
TargetType: api.ServiceTargetTargetType(target.TargetType),
|
||||
Enabled: target.Enabled,
|
||||
}
|
||||
st.Options = targetOptionsToAPI(target.Options)
|
||||
opts := targetOptionsToAPI(target.Options)
|
||||
if opts == nil {
|
||||
opts = &api.ServiceTargetOptions{}
|
||||
}
|
||||
opts.ProxyProtocol = &target.ProxyProtocol
|
||||
st.Options = opts
|
||||
apiTargets = append(apiTargets, st)
|
||||
}
|
||||
|
||||
@@ -227,6 +218,9 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
meta.CertificateIssuedAt = s.Meta.CertificateIssuedAt
|
||||
}
|
||||
|
||||
mode := api.ServiceMode(s.Mode)
|
||||
listenPort := int(s.ListenPort)
|
||||
|
||||
resp := &api.Service{
|
||||
Id: s.ID,
|
||||
Name: s.Name,
|
||||
@@ -237,6 +231,9 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
RewriteRedirects: &s.RewriteRedirects,
|
||||
Auth: authConfig,
|
||||
Meta: meta,
|
||||
Mode: &mode,
|
||||
ListenPort: &listenPort,
|
||||
PortAutoAssigned: &s.PortAutoAssigned,
|
||||
}
|
||||
|
||||
if s.ProxyCluster != "" {
|
||||
@@ -247,37 +244,7 @@ func (s *Service) ToAPIResponse() *api.Service {
|
||||
}
|
||||
|
||||
func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConfig proxy.OIDCValidationConfig) *proto.ProxyMapping {
|
||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO: Make path prefix stripping configurable per-target.
|
||||
// Currently the matching prefix is baked into the target URL path,
|
||||
// so the proxy strips-then-re-adds it (effectively a no-op).
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Path: "/", // TODO: support service path
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.Itoa(target.Port))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
pm := &proto.PathMapping{
|
||||
Path: path,
|
||||
Target: targetURL.String(),
|
||||
}
|
||||
|
||||
pm.Options = targetOptionsToProto(target.Options)
|
||||
pathMappings = append(pathMappings, pm)
|
||||
}
|
||||
pathMappings := s.buildPathMappings()
|
||||
|
||||
auth := &proto.Authentication{
|
||||
SessionKey: s.SessionPublicKey,
|
||||
@@ -306,9 +273,58 @@ func (s *Service) ToProtoMapping(operation Operation, authToken string, oidcConf
|
||||
AccountId: s.AccountID,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Mode: s.Mode,
|
||||
ListenPort: int32(s.ListenPort), //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
// buildPathMappings constructs PathMapping entries from targets.
|
||||
// For HTTP/HTTPS, each target becomes a path-based route with a full URL.
|
||||
// For L4/TLS, a single target maps to a host:port address.
|
||||
func (s *Service) buildPathMappings() []*proto.PathMapping {
|
||||
pathMappings := make([]*proto.PathMapping, 0, len(s.Targets))
|
||||
for _, target := range s.Targets {
|
||||
if !target.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
if IsL4Protocol(s.Mode) {
|
||||
pm := &proto.PathMapping{
|
||||
Target: net.JoinHostPort(target.Host, strconv.FormatUint(uint64(target.Port), 10)),
|
||||
}
|
||||
opts := l4TargetOptionsToProto(target)
|
||||
if opts != nil {
|
||||
pm.Options = opts
|
||||
}
|
||||
pathMappings = append(pathMappings, pm)
|
||||
continue
|
||||
}
|
||||
|
||||
// HTTP/HTTPS: build full URL
|
||||
targetURL := url.URL{
|
||||
Scheme: target.Protocol,
|
||||
Host: target.Host,
|
||||
Path: "/",
|
||||
}
|
||||
if target.Port > 0 && !isDefaultPort(target.Protocol, target.Port) {
|
||||
targetURL.Host = net.JoinHostPort(targetURL.Host, strconv.FormatUint(uint64(target.Port), 10))
|
||||
}
|
||||
|
||||
path := "/"
|
||||
if target.Path != nil {
|
||||
path = *target.Path
|
||||
}
|
||||
|
||||
pm := &proto.PathMapping{
|
||||
Path: path,
|
||||
Target: targetURL.String(),
|
||||
}
|
||||
pm.Options = targetOptionsToProto(target.Options)
|
||||
pathMappings = append(pathMappings, pm)
|
||||
}
|
||||
return pathMappings
|
||||
}
|
||||
|
||||
func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
switch op {
|
||||
case Create:
|
||||
@@ -325,8 +341,8 @@ func operationToProtoType(op Operation) proto.ProxyMappingUpdateType {
|
||||
|
||||
// isDefaultPort reports whether port is the standard default for the given scheme
|
||||
// (443 for https, 80 for http).
|
||||
func isDefaultPort(scheme string, port int) bool {
|
||||
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
|
||||
func isDefaultPort(scheme string, port uint16) bool {
|
||||
return (scheme == TargetProtoHTTPS && port == 443) || (scheme == TargetProtoHTTP && port == 80)
|
||||
}
|
||||
|
||||
// PathRewriteMode controls how the request path is rewritten before forwarding.
|
||||
@@ -346,7 +362,7 @@ func pathRewriteToProto(mode PathRewriteMode) proto.PathRewriteMode {
|
||||
}
|
||||
|
||||
func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
if !opts.SkipTLSVerify && opts.RequestTimeout == 0 && opts.SessionIdleTimeout == 0 && opts.PathRewrite == "" && len(opts.CustomHeaders) == 0 {
|
||||
return nil
|
||||
}
|
||||
apiOpts := &api.ServiceTargetOptions{}
|
||||
@@ -357,6 +373,10 @@ func targetOptionsToAPI(opts TargetOptions) *api.ServiceTargetOptions {
|
||||
s := opts.RequestTimeout.String()
|
||||
apiOpts.RequestTimeout = &s
|
||||
}
|
||||
if opts.SessionIdleTimeout != 0 {
|
||||
s := opts.SessionIdleTimeout.String()
|
||||
apiOpts.SessionIdleTimeout = &s
|
||||
}
|
||||
if opts.PathRewrite != "" {
|
||||
pr := api.ServiceTargetOptionsPathRewrite(opts.PathRewrite)
|
||||
apiOpts.PathRewrite = &pr
|
||||
@@ -382,6 +402,23 @@ func targetOptionsToProto(opts TargetOptions) *proto.PathTargetOptions {
|
||||
return popts
|
||||
}
|
||||
|
||||
// l4TargetOptionsToProto converts L4-relevant target options to proto.
|
||||
func l4TargetOptionsToProto(target *Target) *proto.PathTargetOptions {
|
||||
if !target.ProxyProtocol && target.Options.RequestTimeout == 0 && target.Options.SessionIdleTimeout == 0 {
|
||||
return nil
|
||||
}
|
||||
opts := &proto.PathTargetOptions{
|
||||
ProxyProtocol: target.ProxyProtocol,
|
||||
}
|
||||
if target.Options.RequestTimeout > 0 {
|
||||
opts.RequestTimeout = durationpb.New(target.Options.RequestTimeout)
|
||||
}
|
||||
if target.Options.SessionIdleTimeout > 0 {
|
||||
opts.SessionIdleTimeout = durationpb.New(target.Options.SessionIdleTimeout)
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions, error) {
|
||||
var opts TargetOptions
|
||||
if o.SkipTlsVerify != nil {
|
||||
@@ -394,6 +431,13 @@ func targetOptionsFromAPI(idx int, o *api.ServiceTargetOptions) (TargetOptions,
|
||||
}
|
||||
opts.RequestTimeout = d
|
||||
}
|
||||
if o.SessionIdleTimeout != nil {
|
||||
d, err := time.ParseDuration(*o.SessionIdleTimeout)
|
||||
if err != nil {
|
||||
return opts, fmt.Errorf("target %d: parse session_idle_timeout %q: %w", idx, *o.SessionIdleTimeout, err)
|
||||
}
|
||||
opts.SessionIdleTimeout = d
|
||||
}
|
||||
if o.PathRewrite != nil {
|
||||
opts.PathRewrite = PathRewriteMode(*o.PathRewrite)
|
||||
}
|
||||
@@ -408,15 +452,49 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
s.Domain = req.Domain
|
||||
s.AccountID = accountID
|
||||
|
||||
targets := make([]*Target, 0, len(req.Targets))
|
||||
for i, apiTarget := range req.Targets {
|
||||
if req.Mode != nil {
|
||||
s.Mode = string(*req.Mode)
|
||||
}
|
||||
if req.ListenPort != nil {
|
||||
s.ListenPort = uint16(*req.ListenPort) //nolint:gosec
|
||||
}
|
||||
|
||||
targets, err := targetsFromAPI(accountID, req.Targets)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Targets = targets
|
||||
s.Enabled = req.Enabled
|
||||
|
||||
if req.PassHostHeader != nil {
|
||||
s.PassHostHeader = *req.PassHostHeader
|
||||
}
|
||||
if req.RewriteRedirects != nil {
|
||||
s.RewriteRedirects = *req.RewriteRedirects
|
||||
}
|
||||
|
||||
if req.Auth != nil {
|
||||
s.Auth = authFromAPI(req.Auth)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func targetsFromAPI(accountID string, apiTargetsPtr *[]api.ServiceTarget) ([]*Target, error) {
|
||||
var apiTargets []api.ServiceTarget
|
||||
if apiTargetsPtr != nil {
|
||||
apiTargets = *apiTargetsPtr
|
||||
}
|
||||
|
||||
targets := make([]*Target, 0, len(apiTargets))
|
||||
for i, apiTarget := range apiTargets {
|
||||
target := &Target{
|
||||
AccountID: accountID,
|
||||
Path: apiTarget.Path,
|
||||
Port: apiTarget.Port,
|
||||
Port: uint16(apiTarget.Port), //nolint:gosec // validated by API layer
|
||||
Protocol: string(apiTarget.Protocol),
|
||||
TargetId: apiTarget.TargetId,
|
||||
TargetType: string(apiTarget.TargetType),
|
||||
TargetType: TargetType(apiTarget.TargetType),
|
||||
Enabled: apiTarget.Enabled,
|
||||
}
|
||||
if apiTarget.Host != nil {
|
||||
@@ -425,49 +503,42 @@ func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) erro
|
||||
if apiTarget.Options != nil {
|
||||
opts, err := targetOptionsFromAPI(i, apiTarget.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
target.Options = opts
|
||||
if apiTarget.Options.ProxyProtocol != nil {
|
||||
target.ProxyProtocol = *apiTarget.Options.ProxyProtocol
|
||||
}
|
||||
}
|
||||
targets = append(targets, target)
|
||||
}
|
||||
s.Targets = targets
|
||||
return targets, nil
|
||||
}
|
||||
|
||||
s.Enabled = req.Enabled
|
||||
|
||||
if req.PassHostHeader != nil {
|
||||
s.PassHostHeader = *req.PassHostHeader
|
||||
}
|
||||
|
||||
if req.RewriteRedirects != nil {
|
||||
s.RewriteRedirects = *req.RewriteRedirects
|
||||
}
|
||||
|
||||
if req.Auth.PasswordAuth != nil {
|
||||
s.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: req.Auth.PasswordAuth.Enabled,
|
||||
Password: req.Auth.PasswordAuth.Password,
|
||||
func authFromAPI(reqAuth *api.ServiceAuthConfig) AuthConfig {
|
||||
var auth AuthConfig
|
||||
if reqAuth.PasswordAuth != nil {
|
||||
auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: reqAuth.PasswordAuth.Enabled,
|
||||
Password: reqAuth.PasswordAuth.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Auth.PinAuth != nil {
|
||||
s.Auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: req.Auth.PinAuth.Enabled,
|
||||
Pin: req.Auth.PinAuth.Pin,
|
||||
if reqAuth.PinAuth != nil {
|
||||
auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: reqAuth.PinAuth.Enabled,
|
||||
Pin: reqAuth.PinAuth.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if req.Auth.BearerAuth != nil {
|
||||
if reqAuth.BearerAuth != nil {
|
||||
bearerAuth := &BearerAuthConfig{
|
||||
Enabled: req.Auth.BearerAuth.Enabled,
|
||||
Enabled: reqAuth.BearerAuth.Enabled,
|
||||
}
|
||||
if req.Auth.BearerAuth.DistributionGroups != nil {
|
||||
bearerAuth.DistributionGroups = *req.Auth.BearerAuth.DistributionGroups
|
||||
if reqAuth.BearerAuth.DistributionGroups != nil {
|
||||
bearerAuth.DistributionGroups = *reqAuth.BearerAuth.DistributionGroups
|
||||
}
|
||||
s.Auth.BearerAuth = bearerAuth
|
||||
auth.BearerAuth = bearerAuth
|
||||
}
|
||||
|
||||
return nil
|
||||
return auth
|
||||
}
|
||||
|
||||
func (s *Service) Validate() error {
|
||||
@@ -478,14 +549,69 @@ func (s *Service) Validate() error {
|
||||
return errors.New("service name exceeds maximum length of 255 characters")
|
||||
}
|
||||
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
}
|
||||
|
||||
if len(s.Targets) == 0 {
|
||||
return errors.New("at least one target is required")
|
||||
}
|
||||
|
||||
if s.Mode == "" {
|
||||
s.Mode = ModeHTTP
|
||||
}
|
||||
|
||||
switch s.Mode {
|
||||
case ModeHTTP:
|
||||
return s.validateHTTPMode()
|
||||
case ModeTCP, ModeUDP:
|
||||
return s.validateTCPUDPMode()
|
||||
case ModeTLS:
|
||||
return s.validateTLSMode()
|
||||
default:
|
||||
return fmt.Errorf("unsupported mode %q", s.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) validateHTTPMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("service domain is required")
|
||||
}
|
||||
if s.ListenPort != 0 {
|
||||
return errors.New("listen_port is not supported for HTTP services")
|
||||
}
|
||||
return s.validateHTTPTargets()
|
||||
}
|
||||
|
||||
func (s *Service) validateTCPUDPMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("domain is required for TCP/UDP services (used for cluster derivation)")
|
||||
}
|
||||
if s.isAuthEnabled() {
|
||||
return errors.New("auth is not supported for TCP/UDP services")
|
||||
}
|
||||
if len(s.Targets) != 1 {
|
||||
return errors.New("TCP/UDP services must have exactly one target")
|
||||
}
|
||||
if s.Mode == ModeUDP && s.Targets[0].ProxyProtocol {
|
||||
return errors.New("proxy_protocol is not supported for UDP services")
|
||||
}
|
||||
return s.validateL4Target(s.Targets[0])
|
||||
}
|
||||
|
||||
func (s *Service) validateTLSMode() error {
|
||||
if s.Domain == "" {
|
||||
return errors.New("domain is required for TLS services (used for SNI matching)")
|
||||
}
|
||||
if s.isAuthEnabled() {
|
||||
return errors.New("auth is not supported for TLS services")
|
||||
}
|
||||
if s.ListenPort == 0 {
|
||||
return errors.New("listen_port is required for TLS services")
|
||||
}
|
||||
if len(s.Targets) != 1 {
|
||||
return errors.New("TLS services must have exactly one target")
|
||||
}
|
||||
return s.validateL4Target(s.Targets[0])
|
||||
}
|
||||
|
||||
func (s *Service) validateHTTPTargets() error {
|
||||
for i, target := range s.Targets {
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost, TargetTypeDomain:
|
||||
@@ -500,6 +626,9 @@ func (s *Service) Validate() error {
|
||||
if target.TargetId == "" {
|
||||
return fmt.Errorf("target %d has empty target_id", i)
|
||||
}
|
||||
if target.ProxyProtocol {
|
||||
return fmt.Errorf("target %d: proxy_protocol is not supported for HTTP services", i)
|
||||
}
|
||||
if err := validateTargetOptions(i, &target.Options); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -508,11 +637,62 @@ func (s *Service) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) validateL4Target(target *Target) error {
|
||||
if target.Port == 0 {
|
||||
return errors.New("target port is required for L4 services")
|
||||
}
|
||||
if target.TargetId == "" {
|
||||
return errors.New("target_id is required for L4 services")
|
||||
}
|
||||
switch target.TargetType {
|
||||
case TargetTypePeer, TargetTypeHost:
|
||||
// OK
|
||||
case TargetTypeSubnet:
|
||||
if target.Host == "" {
|
||||
return errors.New("target host is required for subnet targets")
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("invalid target_type %q for L4 service", target.TargetType)
|
||||
}
|
||||
if target.Path != nil && *target.Path != "" && *target.Path != "/" {
|
||||
return errors.New("path is not supported for L4 services")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Service mode constants.
|
||||
const (
|
||||
maxRequestTimeout = 5 * time.Minute
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
ModeHTTP = "http"
|
||||
ModeTCP = "tcp"
|
||||
ModeUDP = "udp"
|
||||
ModeTLS = "tls"
|
||||
)
|
||||
|
||||
// Target protocol constants (URL scheme for backend connections).
|
||||
const (
|
||||
TargetProtoHTTP = "http"
|
||||
TargetProtoHTTPS = "https"
|
||||
TargetProtoTCP = "tcp"
|
||||
TargetProtoUDP = "udp"
|
||||
)
|
||||
|
||||
// IsL4Protocol returns true if the mode requires port-based routing (TCP, UDP, or TLS).
|
||||
func IsL4Protocol(mode string) bool {
|
||||
return mode == ModeTCP || mode == ModeUDP || mode == ModeTLS
|
||||
}
|
||||
|
||||
// IsPortBasedProtocol returns true if the mode relies on dedicated port allocation.
|
||||
// TLS is excluded because it uses SNI routing and can share ports with other TLS services.
|
||||
func IsPortBasedProtocol(mode string) bool {
|
||||
return mode == ModeTCP || mode == ModeUDP
|
||||
}
|
||||
|
||||
const (
|
||||
maxRequestTimeout = 5 * time.Minute
|
||||
maxSessionIdleTimeout = 10 * time.Minute
|
||||
maxCustomHeaders = 16
|
||||
maxHeaderKeyLen = 128
|
||||
maxHeaderValueLen = 4096
|
||||
)
|
||||
|
||||
// httpHeaderNameRe matches valid HTTP header field names per RFC 7230 token definition.
|
||||
@@ -560,6 +740,15 @@ func validateTargetOptions(idx int, opts *TargetOptions) error {
|
||||
}
|
||||
}
|
||||
|
||||
if opts.SessionIdleTimeout != 0 {
|
||||
if opts.SessionIdleTimeout <= 0 {
|
||||
return fmt.Errorf("target %d: session_idle_timeout must be positive", idx)
|
||||
}
|
||||
if opts.SessionIdleTimeout > maxSessionIdleTimeout {
|
||||
return fmt.Errorf("target %d: session_idle_timeout exceeds maximum of %s", idx, maxSessionIdleTimeout)
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateCustomHeaders(idx, opts.CustomHeaders); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -608,17 +797,49 @@ func containsCRLF(s string) bool {
|
||||
}
|
||||
|
||||
func (s *Service) EventMeta() map[string]any {
|
||||
return map[string]any{"name": s.Name, "domain": s.Domain, "proxy_cluster": s.ProxyCluster, "source": s.Source, "auth": s.isAuthEnabled()}
|
||||
meta := map[string]any{
|
||||
"name": s.Name,
|
||||
"domain": s.Domain,
|
||||
"proxy_cluster": s.ProxyCluster,
|
||||
"source": s.Source,
|
||||
"auth": s.isAuthEnabled(),
|
||||
"mode": s.Mode,
|
||||
}
|
||||
|
||||
if s.ListenPort != 0 {
|
||||
meta["listen_port"] = s.ListenPort
|
||||
}
|
||||
|
||||
if len(s.Targets) > 0 {
|
||||
t := s.Targets[0]
|
||||
if t.ProxyProtocol {
|
||||
meta["proxy_protocol"] = true
|
||||
}
|
||||
if t.Options.RequestTimeout != 0 {
|
||||
meta["request_timeout"] = t.Options.RequestTimeout.String()
|
||||
}
|
||||
if t.Options.SessionIdleTimeout != 0 {
|
||||
meta["session_idle_timeout"] = t.Options.SessionIdleTimeout.String()
|
||||
}
|
||||
}
|
||||
|
||||
return meta
|
||||
}
|
||||
|
||||
func (s *Service) isAuthEnabled() bool {
|
||||
return s.Auth.PasswordAuth != nil || s.Auth.PinAuth != nil || s.Auth.BearerAuth != nil
|
||||
return (s.Auth.PasswordAuth != nil && s.Auth.PasswordAuth.Enabled) ||
|
||||
(s.Auth.PinAuth != nil && s.Auth.PinAuth.Enabled) ||
|
||||
(s.Auth.BearerAuth != nil && s.Auth.BearerAuth.Enabled)
|
||||
}
|
||||
|
||||
func (s *Service) Copy() *Service {
|
||||
targets := make([]*Target, len(s.Targets))
|
||||
for i, target := range s.Targets {
|
||||
targetCopy := *target
|
||||
if target.Path != nil {
|
||||
p := *target.Path
|
||||
targetCopy.Path = &p
|
||||
}
|
||||
if len(target.Options.CustomHeaders) > 0 {
|
||||
targetCopy.Options.CustomHeaders = make(map[string]string, len(target.Options.CustomHeaders))
|
||||
for k, v := range target.Options.CustomHeaders {
|
||||
@@ -628,6 +849,24 @@ func (s *Service) Copy() *Service {
|
||||
targets[i] = &targetCopy
|
||||
}
|
||||
|
||||
authCopy := s.Auth
|
||||
if s.Auth.PasswordAuth != nil {
|
||||
pa := *s.Auth.PasswordAuth
|
||||
authCopy.PasswordAuth = &pa
|
||||
}
|
||||
if s.Auth.PinAuth != nil {
|
||||
pa := *s.Auth.PinAuth
|
||||
authCopy.PinAuth = &pa
|
||||
}
|
||||
if s.Auth.BearerAuth != nil {
|
||||
ba := *s.Auth.BearerAuth
|
||||
if len(s.Auth.BearerAuth.DistributionGroups) > 0 {
|
||||
ba.DistributionGroups = make([]string, len(s.Auth.BearerAuth.DistributionGroups))
|
||||
copy(ba.DistributionGroups, s.Auth.BearerAuth.DistributionGroups)
|
||||
}
|
||||
authCopy.BearerAuth = &ba
|
||||
}
|
||||
|
||||
return &Service{
|
||||
ID: s.ID,
|
||||
AccountID: s.AccountID,
|
||||
@@ -638,12 +877,15 @@ func (s *Service) Copy() *Service {
|
||||
Enabled: s.Enabled,
|
||||
PassHostHeader: s.PassHostHeader,
|
||||
RewriteRedirects: s.RewriteRedirects,
|
||||
Auth: s.Auth,
|
||||
Auth: authCopy,
|
||||
Meta: s.Meta,
|
||||
SessionPrivateKey: s.SessionPrivateKey,
|
||||
SessionPublicKey: s.SessionPublicKey,
|
||||
Source: s.Source,
|
||||
SourcePeer: s.SourcePeer,
|
||||
Mode: s.Mode,
|
||||
ListenPort: s.ListenPort,
|
||||
PortAutoAssigned: s.PortAutoAssigned,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -688,12 +930,16 @@ var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
|
||||
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
|
||||
type ExposeServiceRequest struct {
|
||||
NamePrefix string
|
||||
Port int
|
||||
Protocol string
|
||||
Domain string
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
Port uint16
|
||||
Mode string
|
||||
// TargetProtocol is the protocol used to connect to the peer backend.
|
||||
// For HTTP mode: "http" (default) or "https". For L4 modes: "tcp" or "udp".
|
||||
TargetProtocol string
|
||||
Domain string
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
ListenPort uint16
|
||||
}
|
||||
|
||||
// Validate checks all fields of the expose request.
|
||||
@@ -702,12 +948,20 @@ func (r *ExposeServiceRequest) Validate() error {
|
||||
return errors.New("request cannot be nil")
|
||||
}
|
||||
|
||||
if r.Port < 1 || r.Port > 65535 {
|
||||
if r.Port == 0 {
|
||||
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
|
||||
}
|
||||
|
||||
if r.Protocol != "http" && r.Protocol != "https" {
|
||||
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
|
||||
switch r.Mode {
|
||||
case ModeHTTP, ModeTCP, ModeUDP, ModeTLS:
|
||||
default:
|
||||
return fmt.Errorf("unsupported mode %q", r.Mode)
|
||||
}
|
||||
|
||||
if IsL4Protocol(r.Mode) {
|
||||
if r.Pin != "" || r.Password != "" || len(r.UserGroups) > 0 {
|
||||
return fmt.Errorf("authentication is not supported for %s mode", r.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
|
||||
@@ -729,55 +983,79 @@ func (r *ExposeServiceRequest) Validate() error {
|
||||
|
||||
// ToService builds a Service from the expose request.
|
||||
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
|
||||
service := &Service{
|
||||
svc := &Service{
|
||||
AccountID: accountID,
|
||||
Name: serviceName,
|
||||
Mode: r.Mode,
|
||||
Enabled: true,
|
||||
Targets: []*Target{
|
||||
{
|
||||
AccountID: accountID,
|
||||
Port: r.Port,
|
||||
Protocol: r.Protocol,
|
||||
TargetId: peerID,
|
||||
TargetType: TargetTypePeer,
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
// If domain is empty, CreateServiceFromPeer generates a unique subdomain.
|
||||
// When explicitly provided, the service name is prepended as a subdomain.
|
||||
if r.Domain != "" {
|
||||
svc.Domain = serviceName + "." + r.Domain
|
||||
}
|
||||
|
||||
if IsL4Protocol(r.Mode) {
|
||||
svc.ListenPort = r.Port
|
||||
if r.ListenPort > 0 {
|
||||
svc.ListenPort = r.ListenPort
|
||||
}
|
||||
}
|
||||
|
||||
var targetProto string
|
||||
switch {
|
||||
case !IsL4Protocol(r.Mode):
|
||||
targetProto = TargetProtoHTTP
|
||||
if r.TargetProtocol != "" {
|
||||
targetProto = r.TargetProtocol
|
||||
}
|
||||
case r.Mode == ModeUDP:
|
||||
targetProto = TargetProtoUDP
|
||||
default:
|
||||
targetProto = TargetProtoTCP
|
||||
}
|
||||
svc.Targets = []*Target{
|
||||
{
|
||||
AccountID: accountID,
|
||||
Port: r.Port,
|
||||
Protocol: targetProto,
|
||||
TargetId: peerID,
|
||||
TargetType: TargetTypePeer,
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
|
||||
if r.Domain != "" {
|
||||
service.Domain = serviceName + "." + r.Domain
|
||||
}
|
||||
|
||||
if r.Pin != "" {
|
||||
service.Auth.PinAuth = &PINAuthConfig{
|
||||
svc.Auth.PinAuth = &PINAuthConfig{
|
||||
Enabled: true,
|
||||
Pin: r.Pin,
|
||||
}
|
||||
}
|
||||
|
||||
if r.Password != "" {
|
||||
service.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
svc.Auth.PasswordAuth = &PasswordAuthConfig{
|
||||
Enabled: true,
|
||||
Password: r.Password,
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.UserGroups) > 0 {
|
||||
service.Auth.BearerAuth = &BearerAuthConfig{
|
||||
svc.Auth.BearerAuth = &BearerAuthConfig{
|
||||
Enabled: true,
|
||||
DistributionGroups: r.UserGroups,
|
||||
}
|
||||
}
|
||||
|
||||
return service
|
||||
return svc
|
||||
}
|
||||
|
||||
// ExposeServiceResponse contains the result of a successful peer expose creation.
|
||||
type ExposeServiceResponse struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
PortAutoAssigned bool
|
||||
}
|
||||
|
||||
// GenerateExposeName generates a random service name for peer-exposed services.
|
||||
|
||||
@@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) {
|
||||
func TestValidate_NoTargets(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets = nil
|
||||
assert.ErrorContains(t, rp.Validate(), "at least one target")
|
||||
assert.ErrorContains(t, rp.Validate(), "at least one target is required")
|
||||
}
|
||||
|
||||
func TestValidate_EmptyTargetId(t *testing.T) {
|
||||
@@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
|
||||
func TestIsDefaultPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
scheme string
|
||||
port int
|
||||
port uint16
|
||||
want bool
|
||||
}{
|
||||
{"http", 80, true},
|
||||
@@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
|
||||
name string
|
||||
protocol string
|
||||
host string
|
||||
port int
|
||||
port uint16
|
||||
wantTarget string
|
||||
}{
|
||||
{
|
||||
@@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) {
|
||||
func TestExposeServiceRequest_ToService(t *testing.T) {
|
||||
t.Run("basic HTTP service", func(t *testing.T) {
|
||||
req := &ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Port: 8080,
|
||||
Mode: "http",
|
||||
}
|
||||
|
||||
service := req.ToService("account-1", "peer-1", "mysvc")
|
||||
@@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
|
||||
require.Len(t, service.Targets, 1)
|
||||
|
||||
target := service.Targets[0]
|
||||
assert.Equal(t, 8080, target.Port)
|
||||
assert.Equal(t, uint16(8080), target.Port)
|
||||
assert.Equal(t, "http", target.Protocol)
|
||||
assert.Equal(t, "peer-1", target.TargetId)
|
||||
assert.Equal(t, TargetTypePeer, target.TargetType)
|
||||
@@ -730,3 +730,182 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
|
||||
require.NotNil(t, service.Auth.BearerAuth)
|
||||
})
|
||||
}
|
||||
|
||||
func TestValidate_TLSOnly(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tls-svc",
|
||||
Mode: "tls",
|
||||
Domain: "example.com",
|
||||
ListenPort: 8443,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_TLSMissingListenPort(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tls-svc",
|
||||
Mode: "tls",
|
||||
Domain: "example.com",
|
||||
ListenPort: 0,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
|
||||
},
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "listen_port is required")
|
||||
}
|
||||
|
||||
func TestValidate_TLSMissingDomain(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tls-svc",
|
||||
Mode: "tls",
|
||||
ListenPort: 8443,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
|
||||
},
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "domain is required")
|
||||
}
|
||||
|
||||
func TestValidate_TCPValid(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: "cluster.test",
|
||||
ListenPort: 5432,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_TCPMissingListenPort(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: "cluster.test",
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)")
|
||||
}
|
||||
|
||||
func TestValidate_L4MultipleTargets(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: "cluster.test",
|
||||
ListenPort: 5432,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
|
||||
{TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
|
||||
},
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "exactly one target")
|
||||
}
|
||||
|
||||
func TestValidate_L4TargetMissingPort(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: "cluster.test",
|
||||
ListenPort: 5432,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true},
|
||||
},
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "port is required")
|
||||
}
|
||||
|
||||
func TestValidate_TLSInvalidTargetType(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tls-svc",
|
||||
Mode: "tls",
|
||||
Domain: "example.com",
|
||||
ListenPort: 443,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true},
|
||||
},
|
||||
}
|
||||
assert.Error(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_TLSSubnetValid(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tls-subnet",
|
||||
Mode: "tls",
|
||||
Domain: "example.com",
|
||||
ListenPort: 8443,
|
||||
Targets: []*Target{
|
||||
{TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestValidate_HTTPProxyProtocolRejected(t *testing.T) {
|
||||
rp := validProxy()
|
||||
rp.Targets[0].ProxyProtocol = true
|
||||
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP")
|
||||
}
|
||||
|
||||
func TestValidate_UDPProxyProtocolRejected(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "udp-svc",
|
||||
Mode: "udp",
|
||||
Domain: "cluster.test",
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true},
|
||||
},
|
||||
}
|
||||
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP")
|
||||
}
|
||||
|
||||
func TestValidate_TCPProxyProtocolAllowed(t *testing.T) {
|
||||
rp := &Service{
|
||||
Name: "tcp-svc",
|
||||
Mode: "tcp",
|
||||
Domain: "cluster.test",
|
||||
ListenPort: 5432,
|
||||
Targets: []*Target{
|
||||
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true},
|
||||
},
|
||||
}
|
||||
require.NoError(t, rp.Validate())
|
||||
}
|
||||
|
||||
func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req ExposeServiceRequest
|
||||
}{
|
||||
{
|
||||
name: "tcp with pin",
|
||||
req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"},
|
||||
},
|
||||
{
|
||||
name: "udp with password",
|
||||
req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"},
|
||||
},
|
||||
{
|
||||
name: "tls with user groups",
|
||||
req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.req.Validate()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication is not supported")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
|
||||
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
|
||||
require.NoError(t, req.Validate())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user