[management,proxy,client] Add L4 capabilities (TLS/TCP/UDP) (#5530)

This commit is contained in:
Viktor Liu
2026-03-14 01:36:44 +08:00
committed by GitHub
parent fe9b844511
commit 3e6baea405
90 changed files with 9611 additions and 1397 deletions

View File

@@ -11,19 +11,22 @@ import (
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager rpservice.Manager
manager rpservice.Manager
permissionsManager permissions.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{
manager: manager,
manager: manager,
permissionsManager: permissionsManager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()

View File

@@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) {
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
Port: 8081,
Mode: "http",
})
require.NoError(t, err)
@@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
@@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) {
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) {
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
Port: 8082,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, lookupErr)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
@@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) {
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
Port: 8083,
Mode: "http",
})
require.NoError(t, err)
@@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) {
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
Port: uint16(8090 + i),
Mode: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
Port: 9999,
Mode: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
@@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
Port: 8086,
Mode: "http",
})
require.NoError(t, err)
@@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) {
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
@@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
require.NoError(t, err, "renewed service should survive reaping")
}
// resolveServiceIDByDomain resolves the service ID that backs the given
// domain, failing the test if no such service exists in the store.
func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string {
	t.Helper()
	found, err := s.GetServiceByDomain(context.Background(), domain)
	require.NoError(t, err)
	return found.ID
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()

View File

@@ -0,0 +1,582 @@
package manager
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const testCluster = "test-cluster"
// boolPtr returns a pointer to the given bool value (handy for optional flags in tests).
func boolPtr(v bool) *bool {
	b := v
	return &b
}
// setupL4Test creates a manager with a mock proxy controller for L4 port tests.
// customPortsSupported is what the mock controller reports from
// ClusterSupportsCustomPorts for every cluster; nil models an unknown capability.
// Returns the manager under test, its backing store, and the mock controller.
func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) {
	t.Helper()
	ctrl := gomock.NewController(t)
	ctx := context.Background()

	// Fresh SQL-backed store per test, torn down via t.Cleanup.
	testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
	require.NoError(t, err)
	t.Cleanup(cleanup)

	// Seed one account: expose enabled, an admin user, a connected peer,
	// and the group that gates peer expose.
	err = testStore.SaveAccount(ctx, &types.Account{
		Id:        testAccountID,
		CreatedBy: testUserID,
		Settings: &types.Settings{
			PeerExposeEnabled: true,
			PeerExposeGroups:  []string{testGroupID},
		},
		Users: map[string]*types.User{
			testUserID: {
				Id:        testUserID,
				AccountID: testAccountID,
				Role:      types.UserRoleAdmin,
			},
		},
		Peers: map[string]*nbpeer.Peer{
			testPeerID: {
				ID:        testPeerID,
				AccountID: testAccountID,
				Key:       "test-key",
				DNSLabel:  "test-peer",
				Name:      "test-peer",
				IP:        net.ParseIP("100.64.0.1"),
				Status:    &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
				Meta:      nbpeer.PeerSystemMeta{Hostname: "test-peer"},
			},
		},
		Groups: map[string]*types.Group{
			testGroupID: {
				ID:        testGroupID,
				AccountID: testAccountID,
				Name:      "Expose Group",
			},
		},
	})
	require.NoError(t, err)
	err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
	require.NoError(t, err)

	// Mock controller: capability answer is fixed for the test; cluster update
	// and OIDC config calls are accepted but not asserted on.
	mockCtrl := proxy.NewMockController(ctrl)
	mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
	mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()

	// Account manager stub: event/peer-update hooks are no-ops; group lookup
	// delegates to the real store so group-gated checks behave normally.
	accountMgr := &mock_server.MockAccountManager{
		StoreEventFunc:         func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
		UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
		GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
			return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
		},
	}

	mgr := &Manager{
		store:              testStore,
		accountManager:     accountMgr,
		permissionsManager: permissions.NewManager(testStore),
		proxyController:    mockCtrl,
		clusterDeriver:     &testClusterDeriver{domains: []string{"test.netbird.io"}},
	}
	mgr.exposeReaper = &exposeReaper{manager: mgr}
	return mgr, testStore, mockCtrl
}
// seedService inserts a service record directly into the store so a test can
// start from a known pre-existing state. Returns the created record.
func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service {
	t.Helper()

	target := &rpservice.Target{
		AccountID:  testAccountID,
		TargetId:   testPeerID,
		TargetType: rpservice.TargetTypePeer,
		Protocol:   protocol,
		Port:       8080,
		Enabled:    true,
	}
	record := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         name,
		Mode:         protocol,
		Domain:       domain,
		ProxyCluster: cluster,
		ListenPort:   port,
		Enabled:      true,
		Source:       "permanent",
		Targets:      []*rpservice.Target{target},
	}
	record.InitNewRecord()

	require.NoError(t, s.CreateService(context.Background(), record))
	return record
}
// TestPortConflict_TCPSamePortCluster verifies two TCP services cannot share
// the same listen port on the same cluster.
func TestPortConflict_TCPSamePortCluster(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	// An existing TCP listener already occupies port 5432 on the cluster.
	seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432)

	candidate := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "conflicting-tcp",
		Mode:         "tcp",
		Domain:       "conflicting-tcp." + testCluster,
		ProxyCluster: testCluster,
		ListenPort:   5432,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
		},
	}
	candidate.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, candidate)
	require.Error(t, err, "TCP+TCP on same port/cluster should be rejected")
	assert.Contains(t, err.Error(), "already in use")
}
// TestPortConflict_UDPSamePortCluster verifies two UDP services cannot share
// the same listen port on the same cluster.
func TestPortConflict_UDPSamePortCluster(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	// An existing UDP listener already occupies port 5432 on the cluster.
	seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432)

	candidate := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "conflicting-udp",
		Mode:         "udp",
		Domain:       "conflicting-udp." + testCluster,
		ProxyCluster: testCluster,
		ListenPort:   5432,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true},
		},
	}
	candidate.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, candidate)
	require.Error(t, err, "UDP+UDP on same port/cluster should be rejected")
	assert.Contains(t, err.Error(), "already in use")
}
// TestPortConflict_TLSSamePortDifferentDomain verifies that two TLS services
// may share one listen port as long as their domains differ (SNI routing
// disambiguates connections on the shared listener).
func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	// Existing TLS listener on 443 for a different domain.
	seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443)

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "new-tls",
		Mode:         "tls",
		Domain:       "app2.example.com",
		ProxyCluster: testCluster,
		ListenPort:   443,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)")
}
// TestPortConflict_TLSSamePortSameDomain verifies that a second TLS service on
// the same port AND the same domain is rejected — SNI cannot distinguish them,
// and the domain-availability check catches the duplicate.
func TestPortConflict_TLSSamePortSameDomain(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	// Existing TLS listener already claims this exact domain on 443.
	seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "duplicate-tls",
		Mode:         "tls",
		Domain:       "app.example.com",
		ProxyCluster: testCluster,
		ListenPort:   443,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	require.Error(t, err, "TLS+TLS on same domain should be rejected")
	assert.Contains(t, err.Error(), "domain already taken")
}
// TestPortConflict_TLSAndTCPSamePort verifies that TLS and raw TCP may share a
// listen port: the proxy multiplexes TLS (via SNI) and raw TCP (fallback) on
// one listener, so cross-protocol conflicts are intentionally not rejected.
func TestPortConflict_TLSAndTCPSamePort(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	// Existing TLS listener on 443.
	seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "new-tcp",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   443,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)")
}
// TestAutoAssign_TCPNoListenPort verifies that a TCP service created without a
// listen port (ListenPort == 0) gets one auto-assigned from the configured
// range and is flagged as auto-assigned.
func TestAutoAssign_TCPNoListenPort(t *testing.T) {
	// Cluster does NOT support custom ports; auto-assign must kick in anyway.
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "auto-tcp",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   0,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	require.NoError(t, err)
	assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
		"auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax)
	assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set")
}
// TestAutoAssign_TCPCustomPortRejectedWhenNotSupported verifies that a
// permanent TCP service requesting an explicit listen port is rejected when
// the cluster does not support custom ports.
func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "custom-tcp",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   5555, // explicit custom port, unsupported by this cluster
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it")
	assert.Contains(t, err.Error(), "custom ports")
}
// TestAutoAssign_TLSCustomPortAlwaysAllowed verifies that TLS services keep
// their requested listen port regardless of cluster custom-port capability
// (TLS is not a port-based protocol for the capability check).
func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "custom-tls",
		Mode:         "tls",
		Domain:       "app.example.com",
		ProxyCluster: testCluster,
		ListenPort:   9999,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability")
	assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
	assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS")
}
// TestAutoAssign_EphemeralOverridesPortWhenNotSupported verifies that an
// ephemeral (peer-expose) TCP service requesting a custom port on a cluster
// without custom-port support falls back to auto-assignment instead of failing
// — unlike a permanent service, which would be rejected.
func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "ephemeral-tcp",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   5555, // requested but unsupported; should be replaced
		Enabled:      true,
		Source:       "ephemeral",
		SourcePeer:   testPeerID,
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
	require.NoError(t, err)
	assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden")
	assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
		"auto-assigned port %d should be in range", svc.ListenPort)
	assert.True(t, svc.PortAutoAssigned)
}
// TestAutoAssign_EphemeralTLSKeepsCustomPort verifies that an ephemeral TLS
// service keeps its requested port even when the cluster lacks custom-port
// support — the fallback override applies only to port-based protocols.
func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "ephemeral-tls",
		Mode:         "tls",
		Domain:       "app.example.com",
		ProxyCluster: testCluster,
		ListenPort:   9999,
		Enabled:      true,
		Source:       "ephemeral",
		SourcePeer:   testPeerID,
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
	require.NoError(t, err)
	assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
	assert.False(t, svc.PortAutoAssigned)
}
// TestAutoAssign_AvoidsExistingPorts verifies that port auto-assignment never
// hands out a port already occupied by another service on the same cluster.
func TestAutoAssign_AvoidsExistingPorts(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	// Occupy a port inside the auto-assign range before creating the new service.
	existingPort := uint16(20000)
	seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort)

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "auto-tcp",
		Mode:         "tcp",
		Domain:       "auto-tcp." + testCluster,
		ProxyCluster: testCluster,
		ListenPort:   0, // request auto-assignment
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	require.NoError(t, err)
	assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing")
	assert.True(t, svc.PortAutoAssigned)
}
// TestAutoAssign_TCPCustomPortAllowedWhenSupported verifies that an explicit
// TCP listen port is preserved (and not flagged auto-assigned) on a cluster
// that supports custom ports.
func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	svc := &rpservice.Service{
		AccountID:    testAccountID,
		Name:         "custom-tcp",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   5555,
		Enabled:      true,
		Source:       "permanent",
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
		},
	}
	svc.InitNewRecord()

	err := mgr.persistNewService(ctx, testAccountID, svc)
	require.NoError(t, err)
	assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported")
	assert.False(t, svc.PortAutoAssigned)
}
// TestUpdate_PreservesExistingListenPort verifies that updating a service with
// ListenPort == 0 keeps the previously stored port rather than re-assigning
// or zeroing it.
func TestUpdate_PreservesExistingListenPort(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)

	updated := &rpservice.Service{
		ID:           existing.ID,
		AccountID:    testAccountID,
		Name:         "tcp-svc-renamed",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   0, // omitted in the update payload
		Enabled:      true,
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
		},
	}

	_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
	require.NoError(t, err)
	assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0")
}
// TestUpdate_AllowsPortChange verifies that an explicit listen-port change in
// an update payload is applied to the service.
func TestUpdate_AllowsPortChange(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)

	update := &rpservice.Service{
		ID:           existing.ID,
		AccountID:    testAccountID,
		Name:         "tcp-svc",
		Mode:         "tcp",
		Domain:       testCluster,
		ProxyCluster: testCluster,
		ListenPort:   54321,
		Enabled:      true,
		Targets: []*rpservice.Target{
			{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
		},
	}

	_, err := mgr.persistServiceUpdate(ctx, testAccountID, update)
	require.NoError(t, err)
	assert.Equal(t, uint16(54321), update.ListenPort, "explicit port change should be applied")
}
// TestCreateServiceFromPeer_TCP verifies the peer-expose flow for TCP on a
// cluster without custom-port support: the service gets a unique subdomain, a
// port auto-assigned, and a tcp:// service URL.
func TestCreateServiceFromPeer_TCP(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
		Port: 5432,
		Mode: "tcp",
	})
	require.NoError(t, err)
	assert.NotEmpty(t, resp.ServiceName)
	assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain")
	assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports")
	assert.Contains(t, resp.ServiceURL, "tcp://")
}
// TestCreateServiceFromPeer_TCP_CustomPort verifies a peer-requested listen
// port is honored when the cluster supports custom ports.
func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(true))

	req := &rpservice.ExposeServiceRequest{
		Port:       5432,
		Mode:       "tcp",
		ListenPort: 15432,
	}
	resp, err := mgr.CreateServiceFromPeer(context.Background(), testAccountID, testPeerID, req)
	require.NoError(t, err)

	assert.False(t, resp.PortAutoAssigned)
	assert.Contains(t, resp.ServiceURL, ":15432")
}
// TestCreateServiceFromPeer_TCP_DefaultListenPort verifies that when a peer
// omits the listen port on a custom-port-capable cluster, the listen port
// defaults to the target port rather than being auto-assigned.
func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
		Port: 5432,
		Mode: "tcp",
	})
	require.NoError(t, err)
	// When no explicit listen port, defaults to target port
	assert.Contains(t, resp.ServiceURL, ":5432")
	assert.False(t, resp.PortAutoAssigned)
}
// TestCreateServiceFromPeer_TLS verifies the peer-expose flow for TLS: a
// subdomain is allocated, the URL uses the tls:// scheme with the requested
// port, and the port is never auto-assigned even without custom-port support.
func TestCreateServiceFromPeer_TLS(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(false))
	ctx := context.Background()

	resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
		Port: 443,
		Mode: "tls",
	})
	require.NoError(t, err)
	assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain")
	assert.Contains(t, resp.ServiceURL, "tls://")
	assert.Contains(t, resp.ServiceURL, ":443")
	// TLS always keeps its port (not port-based protocol for auto-assign)
	assert.False(t, resp.PortAutoAssigned)
}
// TestCreateServiceFromPeer_TCP_StopAndRenew covers the expose lifecycle for a
// TCP service: create, renew, stop, then verify a renew fails once stopped.
func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) {
	mgr, testStore, _ := setupL4Test(t, boolPtr(true))
	ctx := context.Background()

	resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
		Port: 8080,
		Mode: "tcp",
	})
	require.NoError(t, err)

	id := resolveServiceIDByDomain(t, testStore, resp.Domain)
	require.NoError(t, mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, id))
	require.NoError(t, mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, id))

	// Renew after stop should fail
	require.Error(t, mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, id))
}
// TestCreateServiceFromPeer_L4_RejectsAuth verifies that PIN-based auth cannot
// be combined with a raw L4 (tcp) expose request.
func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) {
	mgr, _, _ := setupL4Test(t, boolPtr(true))

	req := &rpservice.ExposeServiceRequest{
		Port: 8080,
		Mode: "tcp",
		Pin:  "123456",
	}
	_, err := mgr.CreateServiceFromPeer(context.Background(), testAccountID, testPeerID, req)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "authentication is not supported")
}

View File

@@ -4,7 +4,9 @@ import (
"context"
"fmt"
"math/rand/v2"
"os"
"slices"
"strconv"
"time"
log "github.com/sirupsen/logrus"
@@ -23,6 +25,45 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const (
	// Default bounds for auto-assigned L4 listen ports. The upper bound stops
	// at 49151, the end of the IANA registered-port range, keeping clear of
	// OS ephemeral ports.
	defaultAutoAssignPortMin uint16 = 10000
	defaultAutoAssignPortMax uint16 = 49151
	// EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports.
	EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN"
	// EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports.
	EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX"
)

// Effective auto-assign port range, resolved once at process start by init.
var (
	autoAssignPortMin = defaultAutoAssignPortMin
	autoAssignPortMax = defaultAutoAssignPortMax
)

// init resolves the auto-assign port range from the environment and reverts
// to the defaults when the configured range is inverted (min > max).
func init() {
	autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin)
	autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax)
	if autoAssignPortMin > autoAssignPortMax {
		log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults",
			EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax)
		autoAssignPortMin = defaultAutoAssignPortMin
		autoAssignPortMax = defaultAutoAssignPortMax
	}
}
// portFromEnv reads a port number from the named environment variable.
// Unset or empty variables silently yield fallback; unparseable values
// (not a base-10 uint16) yield fallback with a warning.
func portFromEnv(key string, fallback uint16) uint16 {
	raw := os.Getenv(key)
	if raw == "" {
		return fallback
	}
	parsed, err := strconv.ParseUint(raw, 10, 16)
	if err != nil {
		log.Warnf("invalid %s value %q, using default %d: %v", key, raw, fallback, err)
		return fallback
	}
	return uint16(parsed)
}
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
@@ -115,6 +156,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
@@ -197,55 +239,19 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
return nil
})
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
@@ -261,11 +267,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
})
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
// Must run before checkPortConflict so the conflict check sees the final port.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
	// Only L4 modes carry a listen port; other services are untouched.
	if !service.IsL4Protocol(svc.Mode) {
		return nil
	}
	// nil (capability unknown) is treated the same as "not supported".
	customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
	if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
		// Permanent services hard-fail; ephemeral (peer-expose) requests fall
		// back to auto-assignment by zeroing the requested port instead.
		if svc.Source != service.SourceEphemeral {
			return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
		}
		svc.ListenPort = 0
	}
	if svc.ListenPort == 0 {
		port, err := m.assignPort(ctx, tx, svc.ProxyCluster)
		if err != nil {
			return err
		}
		svc.ListenPort = port
		svc.PortAutoAssigned = true
	}
	return nil
}
// checkPortConflict rejects L4 services that would conflict on the same listener.
// TCP/UDP listeners must be unique per cluster+protocol+port. TLS listeners may
// share a port across different domains, since SNI routing disambiguates them.
// Cross-protocol overlap (TLS vs raw TCP) is deliberately not checked: the
// proxy router multiplexes both on the same listener.
func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error {
	if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 {
		return nil
	}
	candidates, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort)
	if err != nil {
		return fmt.Errorf("query port conflicts: %w", err)
	}
	for _, other := range candidates {
		switch {
		case other.ID == svc.ID:
			// A service never conflicts with itself (update path).
		case svc.Mode == service.ModeTLS && other.Domain != svc.Domain:
			// Distinct TLS domains may share the port via SNI routing.
		default:
			return status.Errorf(status.AlreadyExists,
				"%s port %d is already in use by service %q on cluster %s",
				svc.Mode, svc.ListenPort, other.Name, svc.ProxyCluster)
		}
	}
	return nil
}
// assignPort picks an available port on the cluster within the auto-assign
// range [autoAssignPortMin, autoAssignPortMax]. It first probes random ports
// for even distribution, then falls back to a linear sweep so a nearly-full
// range still yields a free port deterministically.
func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) {
	services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster)
	if err != nil {
		return 0, fmt.Errorf("query cluster ports: %w", err)
	}
	occupied := make(map[uint16]struct{}, len(services))
	for _, s := range services {
		if s.ListenPort > 0 {
			occupied[s.ListenPort] = struct{}{}
		}
	}
	portRange := int(autoAssignPortMax-autoAssignPortMin) + 1
	for range 100 {
		port := autoAssignPortMin + uint16(rand.IntN(portRange))
		if _, taken := occupied[port]; !taken {
			return port, nil
		}
	}
	// Linear fallback sweep. The loop is written wrap-safe: with
	// `port <= autoAssignPortMax` as the condition, a max of 65535
	// (configurable via NB_PROXY_PORT_MAX) would make port++ wrap to 0 and
	// never terminate when the range is fully occupied.
	for port := autoAssignPortMin; ; port++ {
		if _, taken := occupied[port]; !taken {
			return port, nil
		}
		if port == autoAssignPortMax {
			break
		}
	}
	return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster)
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
	return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
		// Duplicate/limit/domain checks first; they also take the locks that
		// serialize concurrent creates for this peer.
		if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
			return err
		}
		// Resolve the listen port before the conflict check so it validates
		// the final port value.
		if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
			return err
		}
		if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
			return err
		}
		if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
			return err
		}
		if err := transaction.CreateService(ctx, svc); err != nil {
			return fmt.Errorf("create service: %w", err)
		}
		return nil
	})
}
// validateEphemeralPreconditions runs the checks required before creating a
// peer-exposed (ephemeral) service: the peer row is locked, the domain must not
// already be exposed by this peer or be taken by another service, and the peer
// must be below the per-peer expose limit.
func (m *Manager) validateEphemeralPreconditions(ctx context.Context, tx store.Store, accountID, peerID string, svc *service.Service) error {
	// Lock the peer row to serialize concurrent creates for the same peer.
	if _, err := tx.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
		return fmt.Errorf("lock peer row: %w", err)
	}

	alreadyExposed, err := tx.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
	if err != nil {
		return fmt.Errorf("check existing expose: %w", err)
	}
	if alreadyExposed {
		return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
	}

	if err := m.checkDomainAvailable(ctx, tx, svc.Domain, ""); err != nil {
		return err
	}

	activeCount, err := tx.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
	if err != nil {
		return fmt.Errorf("count peer exposes: %w", err)
	}
	if activeCount >= int64(maxExposesPerPeer) {
		return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
	}
	return nil
}
// checkDomainAvailable checks that no other service already uses this domain.
// A NotFound lookup result means the domain is free; any other error is
// propagated with context.
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
	existingService, err := transaction.GetServiceByDomain(ctx, domain)
	if err != nil {
		if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
			// Keep the wrap style used elsewhere in this file (no "failed to"
			// prefix); the duplicated unreachable return was removed.
			return fmt.Errorf("check existing service: %w", err)
		}
		// NotFound: the domain is not taken.
		return nil
	}
@@ -322,6 +472,10 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
@@ -335,12 +489,18 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
@@ -351,23 +511,39 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
service.ProxyCluster = newCluster
svc.ProxyCluster = newCluster
}
}
return nil
}
// validateProtocolChange rejects mode changes on update.
// Only empty<->HTTP is allowed; all other transitions are rejected.
func validateProtocolChange(oldMode, newMode string) error {
	switch {
	case newMode == "", newMode == oldMode:
		// No change requested (or mode omitted): always allowed.
		return nil
	case isHTTPFamily(oldMode) && isHTTPFamily(newMode):
		// "" and "http" are interchangeable members of the HTTP family.
		return nil
	default:
		return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode)
	}
}
// isHTTPFamily reports whether mode belongs to the HTTP family;
// the empty string is treated as HTTP for backward compatibility.
func isHTTPFamily(mode string) bool {
	switch mode {
	case "", "http":
		return true
	default:
		return false
	}
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
@@ -388,6 +564,13 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv
service.SessionPublicKey = existingService.SessionPublicKey
}
// preserveListenPort carries the existing listen port (and its auto-assignment
// flag) over to the updated service when the update did not specify a port.
func (m *Manager) preserveListenPort(svc, existing *service.Service) {
	if svc.ListenPort != 0 || existing.ListenPort == 0 {
		return
	}
	svc.ListenPort = existing.ListenPort
	svc.PortAutoAssigned = existing.PortAutoAssigned
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
@@ -675,6 +858,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
// resolveDefaultDomain returns the domain to use when an expose request did not
// specify one: the service name prefixed onto a randomly chosen cluster domain.
func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) {
	return m.buildRandomDomain(serviceName)
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
@@ -696,9 +883,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
domain, err := m.resolveDefaultDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
return nil, err
}
svc.Domain = domain
}
@@ -739,10 +926,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
serviceURL := "https://" + svc.Domain
if service.IsL4Protocol(svc.Mode) {
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
}
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
ServiceName: svc.Name,
ServiceURL: serviceURL,
Domain: svc.Domain,
PortAutoAssigned: svc.PortAutoAssigned,
}, nil
}
@@ -761,64 +954,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
func (m *Manager) getDefaultClusterDomain() (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
return "", fmt.Errorf("unable to get cluster domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
return "", fmt.Errorf("no cluster domains available")
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
return clusterDomains[rand.IntN(len(clusterDomains))], nil
}
// buildRandomDomain joins the given name with a randomly selected cluster
// domain, producing "<name>.<cluster-domain>".
func (m *Manager) buildRandomDomain(name string) (string, error) {
	base, err := m.getDefaultClusterDomain()
	if err != nil {
		return "", err
	}
	return name + "." + base, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// deleteServiceFromPeer deletes a peer-initiated service identified by service ID.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error {
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode)
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {

View File

@@ -803,8 +803,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -826,9 +826,9 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
Port: 80,
Mode: "http",
Domain: "example.com",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -847,8 +847,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
require.NoError(t, err)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -860,8 +860,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
Port: 0,
Mode: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -878,62 +878,52 @@ func TestExposeServiceRequestValidate(t *testing.T) {
}{
{
name: "valid http request",
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
name: "https mode rejected",
req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"},
wantErr: "unsupported mode",
},
{
name: "port zero rejected",
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
name: "unsupported mode",
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"},
wantErr: "unsupported mode",
},
{
name: "invalid pin format",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
@@ -966,14 +956,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
// First create a service
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false)
require.NoError(t, err)
// Verify service is deleted
@@ -982,16 +972,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true)
require.NoError(t, err)
})
}
@@ -1003,13 +994,14 @@ func TestStopServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
@@ -1022,8 +1014,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -1042,8 +1034,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
Port: 9090,
Mode: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete")
}
@@ -1054,8 +1046,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -1076,21 +1068,22 @@ func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
})
}
@@ -1191,3 +1184,33 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
// TestValidateProtocolChange covers the allowed and rejected mode transitions:
// identity and empty-new-mode transitions pass, intra-HTTP-family transitions
// pass, and everything else is rejected with a "cannot change mode" error.
func TestValidateProtocolChange(t *testing.T) {
	cases := []struct {
		name    string
		oldP    string
		newP    string
		wantErr bool
	}{
		{"empty to http", "", "http", false},
		{"http to http", "http", "http", false},
		{"same protocol", "tcp", "tcp", false},
		{"empty new proto", "tcp", "", false},
		{"http to tcp", "http", "tcp", true},
		{"tcp to udp", "tcp", "udp", true},
		{"tls to http", "tls", "http", true},
		{"udp to tls", "udp", "tls", true},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := validateProtocolChange(tc.oldP, tc.newP)
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			assert.Contains(t, err.Error(), "cannot change mode")
		})
	}
}