merge main

This commit is contained in:
crn4
2026-03-24 14:50:03 +01:00
269 changed files with 20324 additions and 3434 deletions

View File

@@ -10,22 +10,33 @@ import (
"github.com/netbirdio/netbird/shared/management/proto"
)
// AccessLogProtocol identifies the transport protocol of an access log entry.
type AccessLogProtocol string
const (
AccessLogProtocolHTTP AccessLogProtocol = "http"
AccessLogProtocolTCP AccessLogProtocol = "tcp"
AccessLogProtocolUDP AccessLogProtocol = "udp"
)
type AccessLogEntry struct {
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
ID string `gorm:"primaryKey"`
AccountID string `gorm:"index"`
ServiceID string `gorm:"index"`
Timestamp time.Time `gorm:"index"`
GeoLocation peer.Location `gorm:"embedded;embeddedPrefix:location_"`
SubdivisionCode string
Method string `gorm:"index"`
Host string `gorm:"index"`
Path string `gorm:"index"`
Duration time.Duration `gorm:"index"`
StatusCode int `gorm:"index"`
Reason string
UserId string `gorm:"index"`
AuthMethodUsed string `gorm:"index"`
BytesUpload int64 `gorm:"index"`
BytesDownload int64 `gorm:"index"`
Protocol AccessLogProtocol `gorm:"index"`
}
// FromProto creates an AccessLogEntry from a proto.AccessLog
@@ -43,17 +54,22 @@ func (a *AccessLogEntry) FromProto(serviceLog *proto.AccessLog) {
a.AccountID = serviceLog.GetAccountId()
a.BytesUpload = serviceLog.GetBytesUpload()
a.BytesDownload = serviceLog.GetBytesDownload()
a.Protocol = AccessLogProtocol(serviceLog.GetProtocol())
if sourceIP := serviceLog.GetSourceIp(); sourceIP != "" {
if ip, err := netip.ParseAddr(sourceIP); err == nil {
a.GeoLocation.ConnectionIP = net.IP(ip.AsSlice())
if addr, err := netip.ParseAddr(sourceIP); err == nil {
addr = addr.Unmap()
a.GeoLocation.ConnectionIP = net.IP(addr.AsSlice())
}
}
if !serviceLog.GetAuthSuccess() {
a.Reason = "Authentication failed"
} else if serviceLog.GetResponseCode() >= 400 {
a.Reason = "Request failed"
// Only set reason for HTTP entries. L4 entries have no auth or status code.
if a.Protocol == "" || a.Protocol == AccessLogProtocolHTTP {
if !serviceLog.GetAuthSuccess() {
a.Reason = "Authentication failed"
} else if serviceLog.GetResponseCode() >= 400 {
a.Reason = "Request failed"
}
}
}
@@ -90,22 +106,35 @@ func (a *AccessLogEntry) ToAPIResponse() *api.ProxyAccessLog {
cityName = &a.GeoLocation.CityName
}
var subdivisionCode *string
if a.SubdivisionCode != "" {
subdivisionCode = &a.SubdivisionCode
}
var protocol *string
if a.Protocol != "" {
p := string(a.Protocol)
protocol = &p
}
return &api.ProxyAccessLog{
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
Id: a.ID,
ServiceId: a.ServiceID,
Timestamp: a.Timestamp,
Method: a.Method,
Host: a.Host,
Path: a.Path,
DurationMs: int(a.Duration.Milliseconds()),
StatusCode: a.StatusCode,
SourceIp: sourceIP,
Reason: reason,
UserId: userID,
AuthMethodUsed: authMethod,
CountryCode: countryCode,
CityName: cityName,
SubdivisionCode: subdivisionCode,
BytesUpload: a.BytesUpload,
BytesDownload: a.BytesDownload,
Protocol: protocol,
}
}

View File

@@ -41,6 +41,9 @@ func (m *managerImpl) SaveAccessLog(ctx context.Context, logEntry *accesslogs.Ac
logEntry.GeoLocation.CountryCode = location.Country.ISOCode
logEntry.GeoLocation.CityName = location.City.Names.En
logEntry.GeoLocation.GeoNameID = location.City.GeonameID
if len(location.Subdivisions) > 0 {
logEntry.SubdivisionCode = location.Subdivisions[0].ISOCode
}
}
}

View File

@@ -14,6 +14,12 @@ type Domain struct {
TargetCluster string // The proxy cluster this domain should be validated against
Type Type `gorm:"-"`
Validated bool
// SupportsCustomPorts is populated at query time for free domains from the
// proxy cluster capabilities. Not persisted.
SupportsCustomPorts *bool `gorm:"-"`
// RequireSubdomain is populated at query time. When true, the domain
// cannot be used bare and a subdomain label must be prepended. Not persisted.
RequireSubdomain *bool `gorm:"-"`
}
// EventMeta returns activity event metadata for a domain

View File

@@ -42,10 +42,12 @@ func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
resp := api.ReverseProxyDomain{
Domain: d.Domain,
Id: d.ID,
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
Domain: d.Domain,
Id: d.ID,
Type: domainTypeToApi(d.Type),
Validated: d.Validated,
SupportsCustomPorts: d.SupportsCustomPorts,
RequireSubdomain: d.RequireSubdomain,
}
if d.TargetCluster != "" {
resp.TargetCluster = &d.TargetCluster

View File

@@ -0,0 +1,172 @@
package manager
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
)
func TestExtractClusterFromFreeDomain(t *testing.T) {
clusters := []string{"eu1.proxy.netbird.io", "us1.proxy.netbird.io"}
tests := []struct {
name string
domain string
wantOK bool
wantVal string
}{
{
name: "subdomain of cluster matches",
domain: "myapp.eu1.proxy.netbird.io",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "deep subdomain of cluster matches",
domain: "foo.bar.eu1.proxy.netbird.io",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "bare cluster domain matches",
domain: "eu1.proxy.netbird.io",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "unrelated domain does not match",
domain: "example.com",
wantOK: false,
},
{
name: "partial suffix does not match",
domain: "fakeu1.proxy.netbird.io",
wantOK: false,
},
{
name: "second cluster matches",
domain: "app.us1.proxy.netbird.io",
wantOK: true,
wantVal: "us1.proxy.netbird.io",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster, ok := ExtractClusterFromFreeDomain(tc.domain, clusters)
assert.Equal(t, tc.wantOK, ok)
if ok {
assert.Equal(t, tc.wantVal, cluster)
}
})
}
}
func TestExtractClusterFromCustomDomains(t *testing.T) {
customDomains := []*domain.Domain{
{Domain: "example.com", TargetCluster: "eu1.proxy.netbird.io"},
{Domain: "proxy.corp.io", TargetCluster: "us1.proxy.netbird.io"},
}
tests := []struct {
name string
domain string
wantOK bool
wantVal string
}{
{
name: "subdomain of custom domain matches",
domain: "app.example.com",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "bare custom domain matches",
domain: "example.com",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "deep subdomain of custom domain matches",
domain: "a.b.example.com",
wantOK: true,
wantVal: "eu1.proxy.netbird.io",
},
{
name: "subdomain of multi-level custom domain matches",
domain: "app.proxy.corp.io",
wantOK: true,
wantVal: "us1.proxy.netbird.io",
},
{
name: "bare multi-level custom domain matches",
domain: "proxy.corp.io",
wantOK: true,
wantVal: "us1.proxy.netbird.io",
},
{
name: "unrelated domain does not match",
domain: "other.com",
wantOK: false,
},
{
name: "partial suffix does not match custom domain",
domain: "fakeexample.com",
wantOK: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains)
assert.Equal(t, tc.wantOK, ok)
if ok {
assert.Equal(t, tc.wantVal, cluster)
}
})
}
}
func TestExtractClusterFromCustomDomains_OverlappingDomains(t *testing.T) {
customDomains := []*domain.Domain{
{Domain: "example.com", TargetCluster: "cluster-generic"},
{Domain: "app.example.com", TargetCluster: "cluster-app"},
}
tests := []struct {
name string
domain string
wantVal string
}{
{
name: "exact match on more specific domain",
domain: "app.example.com",
wantVal: "cluster-app",
},
{
name: "subdomain of more specific domain",
domain: "api.app.example.com",
wantVal: "cluster-app",
},
{
name: "subdomain of generic domain",
domain: "other.example.com",
wantVal: "cluster-generic",
},
{
name: "bare generic domain",
domain: "example.com",
wantVal: "cluster-generic",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
cluster, ok := extractClusterFromCustomDomains(tc.domain, customDomains)
assert.True(t, ok)
assert.Equal(t, tc.wantVal, cluster)
})
}
}

View File

@@ -34,11 +34,17 @@ type proxyManager interface {
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
}
type clusterCapabilities interface {
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
}
type Manager struct {
store store
validator domain.Validator
proxyManager proxyManager
permissionsManager permissions.Manager
store store
validator domain.Validator
proxyManager proxyManager
clusterCapabilities clusterCapabilities
permissionsManager permissions.Manager
accountManager account.Manager
}
@@ -52,6 +58,11 @@ func NewManager(store store, proxyMgr proxyManager, permissionsManager permissio
}
}
// SetClusterCapabilities sets the cluster capabilities provider for domain queries.
func (m *Manager) SetClusterCapabilities(caps clusterCapabilities) {
m.clusterCapabilities = caps
}
func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*domain.Domain, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
@@ -81,24 +92,35 @@ func (m Manager) GetDomains(ctx context.Context, accountID, userID string) ([]*d
}).Debug("getting domains with proxy allow list")
for _, cluster := range allowList {
ret = append(ret, &domain.Domain{
d := &domain.Domain{
Domain: cluster,
AccountID: accountID,
Type: domain.TypeFree,
Validated: true,
})
}
if m.clusterCapabilities != nil {
d.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(cluster)
d.RequireSubdomain = m.clusterCapabilities.ClusterRequireSubdomain(cluster)
}
ret = append(ret, d)
}
// Add custom domains.
for _, d := range domains {
ret = append(ret, &domain.Domain{
cd := &domain.Domain{
ID: d.ID,
Domain: d.Domain,
AccountID: accountID,
TargetCluster: d.TargetCluster,
Type: domain.TypeCustom,
Validated: d.Validated,
})
}
if m.clusterCapabilities != nil && d.TargetCluster != "" {
cd.SupportsCustomPorts = m.clusterCapabilities.ClusterSupportsCustomPorts(d.TargetCluster)
}
// Custom domains never require a subdomain by default since
// the account owns them and should be able to use the bare domain.
ret = append(ret, cd)
}
return ret, nil
@@ -296,13 +318,19 @@ func (m Manager) getClusterAllowList(ctx context.Context, accountID string) ([]s
return m.proxyManager.GetActiveClusterAddresses(ctx)
}
func extractClusterFromCustomDomains(domain string, customDomains []*domain.Domain) (string, bool) {
for _, customDomain := range customDomains {
if strings.HasSuffix(domain, "."+customDomain.Domain) {
return customDomain.TargetCluster, true
func extractClusterFromCustomDomains(serviceDomain string, customDomains []*domain.Domain) (string, bool) {
bestCluster := ""
bestLen := -1
for _, cd := range customDomains {
if serviceDomain != cd.Domain && !strings.HasSuffix(serviceDomain, "."+cd.Domain) {
continue
}
if l := len(cd.Domain); l > bestLen {
bestLen = l
bestCluster = cd.TargetCluster
}
}
return "", false
return bestCluster, bestLen >= 0
}
// ExtractClusterFromFreeDomain extracts the cluster address from a free domain.
@@ -310,7 +338,7 @@ func extractClusterFromCustomDomains(domain string, customDomains []*domain.Doma
// It matches the domain suffix against available clusters and returns the matching cluster.
func ExtractClusterFromFreeDomain(domain string, availableClusters []string) (string, bool) {
for _, cluster := range availableClusters {
if strings.HasSuffix(domain, "."+cluster) {
if domain == cluster || strings.HasSuffix(domain, "."+cluster) {
return cluster, true
}
}

View File

@@ -96,51 +96,3 @@ func TestGetClusterAllowList_BYOPEmptySlice_FallbackToShared(t *testing.T) {
assert.Equal(t, []string{"eu.proxy.netbird.io"}, result)
}
func TestExtractClusterFromFreeDomain(t *testing.T) {
clusters := []string{"eu.proxy.netbird.io", "us.proxy.netbird.io"}
tests := []struct {
name string
domain string
wantCluster string
wantOK bool
}{
{
name: "matches EU cluster",
domain: "myapp.abc123.eu.proxy.netbird.io",
wantCluster: "eu.proxy.netbird.io",
wantOK: true,
},
{
name: "matches US cluster",
domain: "myapp.xyz789.us.proxy.netbird.io",
wantCluster: "us.proxy.netbird.io",
wantOK: true,
},
{
name: "no match - custom domain",
domain: "app.example.com",
wantOK: false,
},
{
name: "no match - partial cluster name",
domain: "proxy.netbird.io",
wantOK: false,
},
{
name: "exact cluster name - no prefix",
domain: "eu.proxy.netbird.io",
wantOK: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cluster, ok := ExtractClusterFromFreeDomain(tt.domain, clusters)
assert.Equal(t, tt.wantOK, ok)
if tt.wantOK {
assert.Equal(t, tt.wantCluster, cluster)
}
})
}
}

View File

@@ -13,9 +13,10 @@ import (
type Manager interface {
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, accountID *string) error
Disconnect(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID string) error
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveClusters(ctx context.Context) ([]Cluster, error)
CleanupStale(ctx context.Context, inactivityDuration time.Duration) error
GetAccountProxy(ctx context.Context, accountID string) (*Proxy, error)
CountAccountProxies(ctx context.Context, accountID string) (int64, error)
@@ -38,4 +39,6 @@ type Controller interface {
RegisterProxyToCluster(ctx context.Context, clusterAddr, proxyID string) error
UnregisterProxyFromCluster(ctx context.Context, clusterAddr, proxyID string) error
GetProxiesForCluster(clusterAddr string) []string
ClusterSupportsCustomPorts(clusterAddr string) *bool
ClusterRequireSubdomain(clusterAddr string) *bool
}

View File

@@ -72,6 +72,17 @@ func (c *GRPCController) UnregisterProxyFromCluster(ctx context.Context, cluster
return nil
}
// ClusterSupportsCustomPorts returns whether any proxy in the cluster supports custom ports.
func (c *GRPCController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterSupportsCustomPorts(clusterAddr)
}
// ClusterRequireSubdomain returns whether the cluster requires a subdomain label.
// Returns nil when no proxy has reported the capability (defaults to false).
func (c *GRPCController) ClusterRequireSubdomain(clusterAddr string) *bool {
return c.proxyGRPCServer.ClusterRequireSubdomain(clusterAddr)
}
// GetProxiesForCluster returns all proxy IDs registered for a specific cluster.
func (c *GRPCController) GetProxiesForCluster(clusterAddr string) []string {
proxySet, ok := c.clusterProxies.Load(clusterAddr)

View File

@@ -14,9 +14,10 @@ import (
type store interface {
SaveProxy(ctx context.Context, p *proxy.Proxy) error
DisconnectProxy(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID string) error
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
GetActiveProxyClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error)
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetProxyByAccountID(ctx context.Context, accountID string) (*proxy.Proxy, error)
CountProxiesByAccountID(ctx context.Context, accountID string) (int64, error)
@@ -85,11 +86,13 @@ func (m *Manager) Disconnect(ctx context.Context, proxyID string) error {
}
// Heartbeat updates the proxy's last seen timestamp
func (m *Manager) Heartbeat(ctx context.Context, proxyID string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID); err != nil {
func (m *Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
return err
}
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
m.metrics.IncrementProxyHeartbeatCount()
return nil
}
@@ -104,6 +107,16 @@ func (m *Manager) GetActiveClusterAddresses(ctx context.Context) ([]string, erro
return addresses, nil
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m Manager) GetActiveClusters(ctx context.Context) ([]proxy.Cluster, error) {
clusters, err := m.store.GetActiveProxyClusters(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to get active proxy clusters: %v", err)
return nil, err
}
return clusters, nil
}
// CleanupStale removes proxies that haven't sent heartbeat in the specified duration
func (m *Manager) CleanupStale(ctx context.Context, inactivityDuration time.Duration) error {
if err := m.store.CleanupStaleProxies(ctx, inactivityDuration); err != nil {

View File

@@ -16,7 +16,7 @@ import (
type mockStore struct {
saveProxyFunc func(ctx context.Context, p *proxy.Proxy) error
disconnectProxyFunc func(ctx context.Context, proxyID string) error
updateProxyHeartbeatFunc func(ctx context.Context, proxyID string) error
updateProxyHeartbeatFunc func(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
getActiveProxyClusterAddressesFunc func(ctx context.Context) ([]string, error)
getActiveProxyClusterAddressesForAccFunc func(ctx context.Context, accountID string) ([]string, error)
cleanupStaleProxiesFunc func(ctx context.Context, d time.Duration) error
@@ -38,9 +38,9 @@ func (m *mockStore) DisconnectProxy(ctx context.Context, proxyID string) error {
}
return nil
}
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID string) error {
func (m *mockStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
if m.updateProxyHeartbeatFunc != nil {
return m.updateProxyHeartbeatFunc(ctx, proxyID)
return m.updateProxyHeartbeatFunc(ctx, proxyID, clusterAddress, ipAddress)
}
return nil
}
@@ -56,6 +56,9 @@ func (m *mockStore) GetActiveProxyClusterAddressesForAccount(ctx context.Context
}
return nil, nil
}
func (m *mockStore) GetActiveProxyClusters(_ context.Context) ([]proxy.Cluster, error) {
return nil, nil
}
func (m *mockStore) CleanupStaleProxies(ctx context.Context, d time.Duration) error {
if m.cleanupStaleProxiesFunc != nil {
return m.cleanupStaleProxiesFunc(ctx, d)

View File

@@ -93,7 +93,6 @@ func (mr *MockManagerMockRecorder) GetActiveClusterAddresses(ctx interface{}) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddresses", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddresses), ctx)
}
// GetActiveClusterAddressesForAccount mocks base method.
func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, accountID string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusterAddressesForAccount", ctx, accountID)
@@ -102,24 +101,37 @@ func (m *MockManager) GetActiveClusterAddressesForAccount(ctx context.Context, a
return ret0, ret1
}
// GetActiveClusterAddressesForAccount indicates an expected call of GetActiveClusterAddressesForAccount.
func (mr *MockManagerMockRecorder) GetActiveClusterAddressesForAccount(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusterAddressesForAccount", reflect.TypeOf((*MockManager)(nil).GetActiveClusterAddressesForAccount), ctx, accountID)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID string) error {
func (m *MockManager) GetActiveClusters(ctx context.Context) ([]Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID)
ret := m.ctrl.Call(m, "GetActiveClusters", ctx)
ret0, _ := ret[0].([]Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx)
}
// Heartbeat mocks base method.
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
ret0, _ := ret[0].(error)
return ret0
}
// Heartbeat indicates an expected call of Heartbeat.
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
}
// GetAccountProxy mocks base method.
@@ -204,6 +216,34 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
return m.recorder
}
// ClusterSupportsCustomPorts mocks base method.
func (m *MockController) ClusterSupportsCustomPorts(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterSupportsCustomPorts", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterSupportsCustomPorts indicates an expected call of ClusterSupportsCustomPorts.
func (mr *MockControllerMockRecorder) ClusterSupportsCustomPorts(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterSupportsCustomPorts", reflect.TypeOf((*MockController)(nil).ClusterSupportsCustomPorts), clusterAddr)
}
// ClusterRequireSubdomain mocks base method.
func (m *MockController) ClusterRequireSubdomain(clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ClusterRequireSubdomain", clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// ClusterRequireSubdomain indicates an expected call of ClusterRequireSubdomain.
func (mr *MockControllerMockRecorder) ClusterRequireSubdomain(clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ClusterRequireSubdomain", reflect.TypeOf((*MockController)(nil).ClusterRequireSubdomain), clusterAddr)
}
// GetOIDCValidationConfig mocks base method.
func (m *MockController) GetOIDCValidationConfig() OIDCValidationConfig {
m.ctrl.T.Helper()

View File

@@ -29,3 +29,9 @@ type Proxy struct {
func (Proxy) TableName() string {
return "proxies"
}
// Cluster represents a group of proxy nodes serving the same address.
type Cluster struct {
Address string
ConnectedProxies int
}

View File

@@ -4,9 +4,12 @@ package service
import (
"context"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
type Manager interface {
GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error)
GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error)
GetService(ctx context.Context, accountID, userID, serviceID string) (*Service, error)
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
@@ -22,8 +25,8 @@ type Manager interface {
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
StartExposeReaper(ctx context.Context)
GetServiceByDomain(ctx context.Context, domain string) (*Service, error)
}

View File

@@ -9,6 +9,7 @@ import (
reflect "reflect"
gomock "github.com/golang/mock/gomock"
proxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)
// MockManager is a mock of Manager interface.
@@ -107,6 +108,21 @@ func (mr *MockManagerMockRecorder) GetAccountServices(ctx, accountID interface{}
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockManager)(nil).GetAccountServices), ctx, accountID)
}
// GetActiveClusters mocks base method.
func (m *MockManager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetActiveClusters", ctx, accountID, userID)
ret0, _ := ret[0].([]proxy.Cluster)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetActiveClusters indicates an expected call of GetActiveClusters.
func (mr *MockManagerMockRecorder) GetActiveClusters(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetActiveClusters", reflect.TypeOf((*MockManager)(nil).GetActiveClusters), ctx, accountID, userID)
}
// GetAllServices mocks base method.
func (m *MockManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*Service, error) {
m.ctrl.T.Helper()
@@ -226,17 +242,17 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
}
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// SetCertificateIssuedAt mocks base method.
@@ -280,17 +296,17 @@ func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Ca
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// UpdateService mocks base method.

View File

@@ -11,19 +11,22 @@ import (
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type handler struct {
manager rpservice.Manager
manager rpservice.Manager
permissionsManager permissions.Manager
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{
manager: manager,
manager: manager,
permissionsManager: permissionsManager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
@@ -31,6 +34,7 @@ func RegisterEndpoints(manager rpservice.Manager, domainManager domainmanager.Ma
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
router.HandleFunc("/reverse-proxies/clusters", h.getClusters).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
@@ -174,3 +178,27 @@ func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getClusters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
clusters, err := h.manager.GetActiveClusters(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
apiClusters := make([]api.ProxyCluster, 0, len(clusters))
for _, c := range clusters {
apiClusters = append(apiClusters, api.ProxyCluster{
Address: c.Address,
ConnectedProxies: c.ConnectedProxies,
})
}
util.WriteJSONObject(r.Context(), w, apiClusters)
}

View File

@@ -18,8 +18,8 @@ func TestReapExpiredExposes(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -28,8 +28,8 @@ func TestReapExpiredExposes(t *testing.T) {
// Create a non-expired service
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8081,
Protocol: "http",
Port: 8081,
Mode: "http",
})
require.NoError(t, err)
@@ -49,15 +49,16 @@ func TestReapAlreadyDeletedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Delete the service before reaping
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Reaping should handle the already-deleted service gracefully
@@ -70,8 +71,8 @@ func TestConcurrentReapAndRenew(t *testing.T) {
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -108,17 +109,19 @@ func TestRenewEphemeralService(t *testing.T) {
t.Run("renew succeeds for active service", func(t *testing.T) {
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8082,
Protocol: "http",
Port: 8082,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, lookupErr := mgr.store.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, lookupErr)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
})
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "no active expose session")
})
@@ -133,8 +136,8 @@ func TestCountAndExistsEphemeralServices(t *testing.T) {
assert.Equal(t, int64(0), count)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8083,
Protocol: "http",
Port: 8083,
Mode: "http",
})
require.NoError(t, err)
@@ -157,15 +160,15 @@ func TestMaxExposesPerPeerEnforced(t *testing.T) {
for i := range maxExposesPerPeer {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8090 + i,
Protocol: "http",
Port: uint16(8090 + i),
Mode: "http",
})
require.NoError(t, err, "expose %d should succeed", i)
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9999,
Protocol: "http",
Port: 9999,
Mode: "http",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
@@ -176,8 +179,8 @@ func TestReapSkipsRenewedService(t *testing.T) {
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8086,
Protocol: "http",
Port: 8086,
Mode: "http",
})
require.NoError(t, err)
@@ -185,7 +188,9 @@ func TestReapSkipsRenewedService(t *testing.T) {
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
// Renew it before the reaper runs
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svc, err := testStore.GetServiceByDomain(ctx, resp.Domain)
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svc.ID)
require.NoError(t, err)
// Reaper should skip it because the re-check sees a fresh timestamp
@@ -195,6 +200,14 @@ func TestReapSkipsRenewedService(t *testing.T) {
require.NoError(t, err, "renewed service should survive reaping")
}
// resolveServiceIDByDomain looks up a service ID by domain in tests.
func resolveServiceIDByDomain(t *testing.T, s store.Store, domain string) string {
t.Helper()
svc, err := s.GetServiceByDomain(context.Background(), domain)
require.NoError(t, err)
return svc.ID
}
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
t.Helper()

View File

@@ -0,0 +1,583 @@
package manager
import (
"context"
"net"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const testCluster = "test-cluster"
func boolPtr(v bool) *bool { return &v }
// setupL4Test creates a manager with a mock proxy controller for L4 port tests.
func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Store, *proxy.MockController) {
t.Helper()
ctrl := gomock.NewController(t)
ctx := context.Background()
testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
err = testStore.SaveAccount(ctx, &types.Account{
Id: testAccountID,
CreatedBy: testUserID,
Settings: &types.Settings{
PeerExposeEnabled: true,
PeerExposeGroups: []string{testGroupID},
},
Users: map[string]*types.User{
testUserID: {
Id: testUserID,
AccountID: testAccountID,
Role: types.UserRoleAdmin,
},
},
Peers: map[string]*nbpeer.Peer{
testPeerID: {
ID: testPeerID,
AccountID: testAccountID,
Key: "test-key",
DNSLabel: "test-peer",
Name: "test-peer",
IP: net.ParseIP("100.64.0.1"),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
},
},
Groups: map[string]*types.Group{
testGroupID: {
ID: testGroupID,
AccountID: testAccountID,
Name: "Expose Group",
},
},
})
require.NoError(t, err)
err = testStore.AddPeerToGroup(ctx, testAccountID, testPeerID, testGroupID)
require.NoError(t, err)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterSupportsCustomPorts(gomock.Any()).Return(customPortsSupported).AnyTimes()
mockCtrl.EXPECT().ClusterRequireSubdomain(gomock.Any()).Return((*bool)(nil)).AnyTimes()
mockCtrl.EXPECT().SendServiceUpdateToCluster(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
mockCtrl.EXPECT().GetOIDCValidationConfig().Return(proxy.OIDCValidationConfig{}).AnyTimes()
accountMgr := &mock_server.MockAccountManager{
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
GetGroupByNameFunc: func(ctx context.Context, accountID, groupName string) (*types.Group, error) {
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, groupName, accountID)
},
}
mgr := &Manager{
store: testStore,
accountManager: accountMgr,
permissionsManager: permissions.NewManager(testStore),
proxyController: mockCtrl,
clusterDeriver: &testClusterDeriver{domains: []string{"test.netbird.io"}},
}
mgr.exposeReaper = &exposeReaper{manager: mgr}
return mgr, testStore, mockCtrl
}
// seedService creates a service directly in the store for test setup.
func seedService(t *testing.T, s store.Store, name, protocol, domain, cluster string, port uint16) *rpservice.Service {
t.Helper()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: name,
Mode: protocol,
Domain: domain,
ProxyCluster: cluster,
ListenPort: port,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: protocol, Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := s.CreateService(context.Background(), svc)
require.NoError(t, err)
return svc
}
func TestPortConflict_TCPSamePortCluster(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tcp", "tcp", testCluster, testCluster, 5432)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "conflicting-tcp",
Mode: "tcp",
Domain: "conflicting-tcp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TCP+TCP on same port/cluster should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestPortConflict_UDPSamePortCluster(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-udp", "udp", testCluster, testCluster, 5432)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "conflicting-udp",
Mode: "udp",
Domain: "conflicting-udp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 5432,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "udp", Port: 9090, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "UDP+UDP on same port/cluster should be rejected")
assert.Contains(t, err.Error(), "already in use")
}
func TestPortConflict_TLSSamePortDifferentDomain(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app1.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "new-tls",
Mode: "tls",
Domain: "app2.example.com",
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS+TLS on same port with different domains should be allowed (SNI routing)")
}
func TestPortConflict_TLSSamePortSameDomain(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "duplicate-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TLS+TLS on same domain should be rejected")
assert.Contains(t, err.Error(), "domain already taken")
}
func TestPortConflict_TLSAndTCPSamePort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
seedService(t, testStore, "existing-tls", "tls", "app.example.com", testCluster, 443)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "new-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 443,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS+TCP on same port should be allowed (multiplexed)")
}
func TestAutoAssign_TCPNoListenPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "auto-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range [%d, %d]", svc.ListenPort, autoAssignPortMin, autoAssignPortMax)
assert.True(t, svc.PortAutoAssigned, "PortAutoAssigned should be set")
}
func TestAutoAssign_TCPCustomPortRejectedWhenNotSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.Error(t, err, "TCP with custom port should be rejected when cluster doesn't support it")
assert.Contains(t, err.Error(), "custom ports")
}
func TestAutoAssign_TLSCustomPortAlwaysAllowed(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
assert.NoError(t, err, "TLS with custom port should always be allowed regardless of cluster capability")
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
assert.False(t, svc.PortAutoAssigned, "PortAutoAssigned should not be set for TLS")
}
func TestAutoAssign_EphemeralOverridesPortWhenNotSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "ephemeral-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "ephemeral",
SourcePeer: testPeerID,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
require.NoError(t, err)
assert.NotEqual(t, uint16(5555), svc.ListenPort, "requested port should be overridden")
assert.True(t, svc.ListenPort >= autoAssignPortMin && svc.ListenPort <= autoAssignPortMax,
"auto-assigned port %d should be in range", svc.ListenPort)
assert.True(t, svc.PortAutoAssigned)
}
func TestAutoAssign_EphemeralTLSKeepsCustomPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "ephemeral-tls",
Mode: "tls",
Domain: "app.example.com",
ProxyCluster: testCluster,
ListenPort: 9999,
Enabled: true,
Source: "ephemeral",
SourcePeer: testPeerID,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8443, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewEphemeralService(ctx, testAccountID, testPeerID, svc)
require.NoError(t, err)
assert.Equal(t, uint16(9999), svc.ListenPort, "TLS listen port should not be overridden")
assert.False(t, svc.PortAutoAssigned)
}
func TestAutoAssign_AvoidsExistingPorts(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existingPort := uint16(20000)
seedService(t, testStore, "existing", "tcp", testCluster, testCluster, existingPort)
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "auto-tcp",
Mode: "tcp",
Domain: "auto-tcp." + testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.NotEqual(t, existingPort, svc.ListenPort, "auto-assigned port should not collide with existing")
assert.True(t, svc.PortAutoAssigned)
}
func TestAutoAssign_TCPCustomPortAllowedWhenSupported(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
svc := &rpservice.Service{
AccountID: testAccountID,
Name: "custom-tcp",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 5555,
Enabled: true,
Source: "permanent",
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 8080, Enabled: true},
},
}
svc.InitNewRecord()
err := mgr.persistNewService(ctx, testAccountID, svc)
require.NoError(t, err)
assert.Equal(t, uint16(5555), svc.ListenPort, "custom port should be preserved when supported")
assert.False(t, svc.PortAutoAssigned)
}
func TestUpdate_PreservesExistingListenPort(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc-renamed",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 0,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.Equal(t, uint16(12345), updated.ListenPort, "existing listen port should be preserved when update sends 0")
}
func TestUpdate_AllowsPortChange(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
existing := seedService(t, testStore, "tcp-svc", "tcp", testCluster, testCluster, 12345)
updated := &rpservice.Service{
ID: existing.ID,
AccountID: testAccountID,
Name: "tcp-svc",
Mode: "tcp",
Domain: testCluster,
ProxyCluster: testCluster,
ListenPort: 54321,
Enabled: true,
Targets: []*rpservice.Target{
{AccountID: testAccountID, TargetId: testPeerID, TargetType: rpservice.TargetTypePeer, Protocol: "tcp", Port: 9090, Enabled: true},
},
}
_, err := mgr.persistServiceUpdate(ctx, testAccountID, updated)
require.NoError(t, err)
assert.Equal(t, uint16(54321), updated.ListenPort, "explicit port change should be applied")
}
func TestCreateServiceFromPeer_TCP(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
})
require.NoError(t, err)
assert.NotEmpty(t, resp.ServiceName)
assert.Contains(t, resp.Domain, ".test.netbird.io", "TCP uses unique subdomain")
assert.True(t, resp.PortAutoAssigned, "port should be auto-assigned when cluster doesn't support custom ports")
assert.Contains(t, resp.ServiceURL, "tcp://")
}
func TestCreateServiceFromPeer_TCP_CustomPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
ListenPort: 15432,
})
require.NoError(t, err)
assert.False(t, resp.PortAutoAssigned)
assert.Contains(t, resp.ServiceURL, ":15432")
}
func TestCreateServiceFromPeer_TCP_DefaultListenPort(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 5432,
Mode: "tcp",
})
require.NoError(t, err)
// When no explicit listen port, defaults to target port
assert.Contains(t, resp.ServiceURL, ":5432")
assert.False(t, resp.PortAutoAssigned)
}
func TestCreateServiceFromPeer_TLS(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(false))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 443,
Mode: "tls",
})
require.NoError(t, err)
assert.Contains(t, resp.Domain, ".test.netbird.io", "TLS uses subdomain")
assert.Contains(t, resp.ServiceURL, "tls://")
assert.Contains(t, resp.ServiceURL, ":443")
// TLS always keeps its port (not port-based protocol for auto-assign)
assert.False(t, resp.PortAutoAssigned)
}
func TestCreateServiceFromPeer_TCP_StopAndRenew(t *testing.T) {
mgr, testStore, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "tcp",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
// Renew after stop should fail
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.Error(t, err)
}
func TestCreateServiceFromPeer_L4_RejectsAuth(t *testing.T) {
mgr, _, _ := setupL4Test(t, boolPtr(true))
ctx := context.Background()
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "tcp",
Pin: "123456",
})
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication is not supported")
}

View File

@@ -4,13 +4,18 @@ import (
"context"
"fmt"
"math/rand/v2"
"net/http"
"os"
"slices"
"strconv"
"time"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
@@ -23,6 +28,45 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
const (
defaultAutoAssignPortMin uint16 = 10000
defaultAutoAssignPortMax uint16 = 49151
// EnvAutoAssignPortMin overrides the lower bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMin = "NB_PROXY_PORT_MIN"
// EnvAutoAssignPortMax overrides the upper bound for auto-assigned L4 listen ports.
EnvAutoAssignPortMax = "NB_PROXY_PORT_MAX"
)
var (
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
)
func init() {
autoAssignPortMin = portFromEnv(EnvAutoAssignPortMin, defaultAutoAssignPortMin)
autoAssignPortMax = portFromEnv(EnvAutoAssignPortMax, defaultAutoAssignPortMax)
if autoAssignPortMin > autoAssignPortMax {
log.Warnf("port range invalid: %s (%d) > %s (%d), using defaults",
EnvAutoAssignPortMin, autoAssignPortMin, EnvAutoAssignPortMax, autoAssignPortMax)
autoAssignPortMin = defaultAutoAssignPortMin
autoAssignPortMax = defaultAutoAssignPortMax
}
}
func portFromEnv(key string, fallback uint16) uint16 {
val := os.Getenv(key)
if val == "" {
return fallback
}
n, err := strconv.ParseUint(val, 10, 16)
if err != nil {
log.Warnf("invalid %s value %q, using default %d: %v", key, val, fallback, err)
return fallback
}
return uint16(n)
}
const unknownHostPlaceholder = "unknown"
// ClusterDeriver derives the proxy cluster from a domain.
@@ -58,6 +102,19 @@ func (m *Manager) StartExposeReaper(ctx context.Context) {
m.exposeReaper.StartExposeReaper(ctx)
}
// GetActiveClusters returns all active proxy clusters with their connected proxy count.
func (m *Manager) GetActiveClusters(ctx context.Context, accountID, userID string) ([]proxy.Cluster, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !ok {
return nil, status.NewPermissionDeniedError()
}
return m.store.GetActiveProxyClusters(ctx)
}
func (m *Manager) GetAllServices(ctx context.Context, accountID, userID string) ([]*service.Service, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Read)
if err != nil {
@@ -115,6 +172,7 @@ func (m *Manager) replaceHostByLookup(ctx context.Context, accountID string, s *
return fmt.Errorf("unknown target type: %s", target.TargetType)
}
}
return nil
}
@@ -178,6 +236,10 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return status.Errorf(status.PreconditionFailed, "could not derive cluster from domain %s: %v", service.Domain, err)
}
service.ProxyCluster = proxyCluster
if err := m.validateSubdomainRequirement(service.Domain, proxyCluster); err != nil {
return err
}
}
service.AccountID = accountID
@@ -187,6 +249,12 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return fmt.Errorf("hash secrets: %w", err)
}
for i, h := range service.Auth.HeaderAuths {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
keyPair, err := sessionkey.GenerateKeyPair()
if err != nil {
return fmt.Errorf("generate session keys: %w", err)
@@ -197,55 +265,33 @@ func (m *Manager) initializeServiceForCreate(ctx context.Context, accountID stri
return nil
}
func (m *Manager) persistNewService(ctx context.Context, accountID string, service *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, ""); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, service); err != nil {
return fmt.Errorf("failed to create service: %w", err)
}
// validateSubdomainRequirement checks whether the domain can be used bare
// (without a subdomain label) on the given cluster. If the cluster reports
// require_subdomain=true and the domain equals the cluster domain, it rejects.
func (m *Manager) validateSubdomainRequirement(domain, cluster string) error {
if domain != cluster {
return nil
})
}
requireSub := m.proxyController.ClusterRequireSubdomain(cluster)
if requireSub != nil && *requireSub {
return status.Errorf(status.InvalidArgument, "domain %s requires a subdomain label", domain)
}
return nil
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
func (m *Manager) persistNewService(ctx context.Context, accountID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
// Lock the peer row to serialize concurrent creates for the same peer.
// Without this, when no ephemeral rows exist yet, FOR UPDATE on the services
// table returns no rows and acquires no locks, allowing concurrent inserts
// to bypass the per-peer limit.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
if svc.Domain != "" {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
@@ -261,11 +307,155 @@ func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, pee
})
}
// ensureL4Port auto-assigns a listen port when needed and validates cluster support.
func (m *Manager) ensureL4Port(ctx context.Context, tx store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) {
return nil
}
customPorts := m.proxyController.ClusterSupportsCustomPorts(svc.ProxyCluster)
if service.IsPortBasedProtocol(svc.Mode) && svc.ListenPort > 0 && (customPorts == nil || !*customPorts) {
if svc.Source != service.SourceEphemeral {
return status.Errorf(status.InvalidArgument, "custom ports not supported on cluster %s", svc.ProxyCluster)
}
svc.ListenPort = 0
}
if svc.ListenPort == 0 {
port, err := m.assignPort(ctx, tx, svc.ProxyCluster)
if err != nil {
return err
}
svc.ListenPort = port
svc.PortAutoAssigned = true
}
return nil
}
// checkPortConflict rejects L4 services that would conflict on the same listener.
// For TCP/UDP: unique per cluster+protocol+port.
// For TLS: unique per cluster+port+domain (SNI routing allows sharing ports).
// Cross-protocol conflicts (TLS vs raw TCP) are intentionally not checked:
// the proxy router multiplexes TLS (via SNI) and raw TCP (via fallback) on the same listener.
func (m *Manager) checkPortConflict(ctx context.Context, transaction store.Store, svc *service.Service) error {
if !service.IsL4Protocol(svc.Mode) || svc.ListenPort == 0 {
return nil
}
existing, err := transaction.GetServicesByClusterAndPort(ctx, store.LockingStrengthUpdate, svc.ProxyCluster, svc.Mode, svc.ListenPort)
if err != nil {
return fmt.Errorf("query port conflicts: %w", err)
}
for _, s := range existing {
if s.ID == svc.ID {
continue
}
// TLS services on the same port are allowed if they have different domains (SNI routing)
if svc.Mode == service.ModeTLS && s.Domain != svc.Domain {
continue
}
return status.Errorf(status.AlreadyExists,
"%s port %d is already in use by service %q on cluster %s",
svc.Mode, svc.ListenPort, s.Name, svc.ProxyCluster)
}
return nil
}
// assignPort picks a random available port on the cluster within the auto-assign range.
func (m *Manager) assignPort(ctx context.Context, tx store.Store, cluster string) (uint16, error) {
services, err := tx.GetServicesByCluster(ctx, store.LockingStrengthUpdate, cluster)
if err != nil {
return 0, fmt.Errorf("query cluster ports: %w", err)
}
occupied := make(map[uint16]struct{}, len(services))
for _, s := range services {
if s.ListenPort > 0 {
occupied[s.ListenPort] = struct{}{}
}
}
portRange := int(autoAssignPortMax-autoAssignPortMin) + 1
for range 100 {
port := autoAssignPortMin + uint16(rand.IntN(portRange))
if _, taken := occupied[port]; !taken {
return port, nil
}
}
for port := autoAssignPortMin; port <= autoAssignPortMax; port++ {
if _, taken := occupied[port]; !taken {
return port, nil
}
}
return 0, status.Errorf(status.PreconditionFailed, "no available ports on cluster %s", cluster)
}
// persistNewEphemeralService creates an ephemeral service inside a single transaction
// that also enforces the duplicate and per-peer limit checks atomically.
// The count and exists queries use FOR UPDATE locking to serialize concurrent creates
// for the same peer, preventing the per-peer limit from being bypassed.
func (m *Manager) persistNewEphemeralService(ctx context.Context, accountID, peerID string, svc *service.Service) error {
return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := m.validateEphemeralPreconditions(ctx, transaction, accountID, peerID, svc); err != nil {
return err
}
if err := m.ensureL4Port(ctx, transaction, svc); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, svc); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, svc.Targets); err != nil {
return err
}
if err := transaction.CreateService(ctx, svc); err != nil {
return fmt.Errorf("create service: %w", err)
}
return nil
})
}
func (m *Manager) validateEphemeralPreconditions(ctx context.Context, transaction store.Store, accountID, peerID string, svc *service.Service) error {
// Lock the peer row to serialize concurrent creates for the same peer.
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID); err != nil {
return fmt.Errorf("lock peer row: %w", err)
}
exists, err := transaction.EphemeralServiceExists(ctx, store.LockingStrengthUpdate, accountID, peerID, svc.Domain)
if err != nil {
return fmt.Errorf("check existing expose: %w", err)
}
if exists {
return status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, ""); err != nil {
return err
}
count, err := transaction.CountEphemeralServicesByPeer(ctx, store.LockingStrengthUpdate, accountID, peerID)
if err != nil {
return fmt.Errorf("count peer exposes: %w", err)
}
if count >= int64(maxExposesPerPeer) {
return status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
return nil
}
// checkDomainAvailable checks that no other service already uses this domain.
func (m *Manager) checkDomainAvailable(ctx context.Context, transaction store.Store, domain, excludeServiceID string) error {
existingService, err := transaction.GetServiceByDomain(ctx, domain)
if err != nil {
if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound {
return fmt.Errorf("failed to check existing service: %w", err)
return fmt.Errorf("check existing service: %w", err)
}
return nil
}
@@ -317,69 +507,140 @@ func (m *Manager) persistServiceUpdate(ctx context.Context, accountID string, se
var updateInfo serviceUpdateInfo
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
m.preserveExistingAuthSecrets(service, existingService)
m.preserveServiceMetadata(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
return m.executeServiceUpdate(ctx, transaction, accountID, service, &updateInfo)
})
return &updateInfo, err
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, service *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, service.Domain, service.ID); err != nil {
func (m *Manager) executeServiceUpdate(ctx context.Context, transaction store.Store, accountID string, service *service.Service, updateInfo *serviceUpdateInfo) error {
existingService, err := transaction.GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, service.ID)
if err != nil {
return err
}
if err := validateProtocolChange(existingService.Mode, service.Mode); err != nil {
return err
}
updateInfo.oldCluster = existingService.ProxyCluster
updateInfo.domainChanged = existingService.Domain != service.Domain
if updateInfo.domainChanged {
if err := m.handleDomainChange(ctx, transaction, accountID, service); err != nil {
return err
}
} else {
service.ProxyCluster = existingService.ProxyCluster
}
if err := m.validateSubdomainRequirement(service.Domain, service.ProxyCluster); err != nil {
return err
}
m.preserveExistingAuthSecrets(service, existingService)
if err := validateHeaderAuthValues(service.Auth.HeaderAuths); err != nil {
return err
}
m.preserveServiceMetadata(service, existingService)
m.preserveListenPort(service, existingService)
updateInfo.serviceEnabledChanged = existingService.Enabled != service.Enabled
if err := m.ensureL4Port(ctx, transaction, service); err != nil {
return err
}
if err := m.checkPortConflict(ctx, transaction, service); err != nil {
return err
}
if err := validateTargetReferences(ctx, transaction, accountID, service.Targets); err != nil {
return err
}
if err := transaction.UpdateService(ctx, service); err != nil {
return fmt.Errorf("update service: %w", err)
}
return nil
}
func (m *Manager) handleDomainChange(ctx context.Context, transaction store.Store, accountID string, svc *service.Service) error {
if err := m.checkDomainAvailable(ctx, transaction, svc.Domain, svc.ID); err != nil {
return err
}
if m.clusterDeriver != nil {
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, service.Domain)
newCluster, err := m.clusterDeriver.DeriveClusterFromDomain(ctx, accountID, svc.Domain)
if err != nil {
log.WithError(err).Warnf("could not derive cluster from domain %s", service.Domain)
log.WithError(err).Warnf("could not derive cluster from domain %s", svc.Domain)
} else {
service.ProxyCluster = newCluster
svc.ProxyCluster = newCluster
}
}
return nil
}
func (m *Manager) preserveExistingAuthSecrets(service, existingService *service.Service) {
if service.Auth.PasswordAuth != nil && service.Auth.PasswordAuth.Enabled &&
// validateProtocolChange rejects mode changes on update.
// Only empty<->HTTP is allowed; all other transitions are rejected.
func validateProtocolChange(oldMode, newMode string) error {
if newMode == "" || newMode == oldMode {
return nil
}
if isHTTPFamily(oldMode) && isHTTPFamily(newMode) {
return nil
}
return status.Errorf(status.InvalidArgument, "cannot change mode from %q to %q", oldMode, newMode)
}
func isHTTPFamily(mode string) bool {
return mode == "" || mode == "http"
}
func (m *Manager) preserveExistingAuthSecrets(svc, existingService *service.Service) {
if svc.Auth.PasswordAuth != nil && svc.Auth.PasswordAuth.Enabled &&
existingService.Auth.PasswordAuth != nil && existingService.Auth.PasswordAuth.Enabled &&
service.Auth.PasswordAuth.Password == "" {
service.Auth.PasswordAuth = existingService.Auth.PasswordAuth
svc.Auth.PasswordAuth.Password == "" {
svc.Auth.PasswordAuth = existingService.Auth.PasswordAuth
}
if service.Auth.PinAuth != nil && service.Auth.PinAuth.Enabled &&
if svc.Auth.PinAuth != nil && svc.Auth.PinAuth.Enabled &&
existingService.Auth.PinAuth != nil && existingService.Auth.PinAuth.Enabled &&
service.Auth.PinAuth.Pin == "" {
service.Auth.PinAuth = existingService.Auth.PinAuth
svc.Auth.PinAuth.Pin == "" {
svc.Auth.PinAuth = existingService.Auth.PinAuth
}
preserveHeaderAuthHashes(svc.Auth.HeaderAuths, existingService.Auth.HeaderAuths)
}
// preserveHeaderAuthHashes fills in empty header auth values from the existing
// service so that unchanged secrets are not lost on update.
func preserveHeaderAuthHashes(headers, existing []*service.HeaderAuthConfig) {
if len(headers) == 0 || len(existing) == 0 {
return
}
existingByHeader := make(map[string]string, len(existing))
for _, h := range existing {
if h != nil && h.Value != "" {
existingByHeader[http.CanonicalHeaderKey(h.Header)] = h.Value
}
}
for _, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
if hash, ok := existingByHeader[http.CanonicalHeaderKey(h.Header)]; ok {
h.Value = hash
}
}
}
}
// validateHeaderAuthValues checks that all enabled header auths have a value
// (either freshly provided or preserved from the existing service).
func validateHeaderAuthValues(headers []*service.HeaderAuthConfig) error {
for i, h := range headers {
if h != nil && h.Enabled && h.Value == "" {
return status.Errorf(status.InvalidArgument, "header_auths[%d]: value is required", i)
}
}
return nil
}
func (m *Manager) preserveServiceMetadata(service, existingService *service.Service) {
@@ -388,11 +649,18 @@ func (m *Manager) preserveServiceMetadata(service, existingService *service.Serv
service.SessionPublicKey = existingService.SessionPublicKey
}
func (m *Manager) preserveListenPort(svc, existing *service.Service) {
if existing.ListenPort > 0 && svc.ListenPort == 0 {
svc.ListenPort = existing.ListenPort
svc.PortAutoAssigned = existing.PortAutoAssigned
}
}
func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID string, s *service.Service, updateInfo *serviceUpdateInfo) {
oidcCfg := m.proxyController.GetOIDCValidationConfig()
switch {
case updateInfo.domainChanged && updateInfo.oldCluster != s.ProxyCluster:
case updateInfo.domainChanged || updateInfo.oldCluster != s.ProxyCluster:
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", oidcCfg), updateInfo.oldCluster)
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", oidcCfg), s.ProxyCluster)
case !s.Enabled && updateInfo.serviceEnabledChanged:
@@ -409,24 +677,53 @@ func validateTargetReferences(ctx context.Context, transaction store.Store, acco
for _, target := range targets {
switch target.TargetType {
case service.TargetTypePeer:
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
if err := validatePeerTarget(ctx, transaction, accountID, target); err != nil {
return err
}
case service.TargetTypeHost, service.TargetTypeSubnet, service.TargetTypeDomain:
if _, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
if err := validateResourceTarget(ctx, transaction, accountID, target); err != nil {
return err
}
default:
return status.Errorf(status.InvalidArgument, "unknown target type %q for target %q", target.TargetType, target.TargetId)
}
}
return nil
}
func validatePeerTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
if _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, target.TargetId); err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "peer target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up peer target %q: %w", target.TargetId, err)
}
return nil
}
func validateResourceTarget(ctx context.Context, transaction store.Store, accountID string, target *service.Target) error {
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, target.TargetId)
if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
return status.Errorf(status.InvalidArgument, "resource target %q not found in account", target.TargetId)
}
return fmt.Errorf("look up resource target %q: %w", target.TargetId, err)
}
return validateResourceTargetType(target, resource)
}
// validateResourceTargetType checks that target_type matches the actual network resource type.
func validateResourceTargetType(target *service.Target, resource *resourcetypes.NetworkResource) error {
expected := resourcetypes.NetworkResourceType(target.TargetType)
if resource.Type != expected {
return status.Errorf(status.InvalidArgument,
"target %q has target_type %q but resource is of type %q",
target.TargetId, target.TargetType, resource.Type,
)
}
return nil
}
func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
if err != nil {
@@ -679,6 +976,10 @@ func (m *Manager) validateExposePermission(ctx context.Context, accountID, peerI
return status.Errorf(status.PermissionDenied, "peer is not in an allowed expose group")
}
func (m *Manager) resolveDefaultDomain(serviceName string) (string, error) {
return m.buildRandomDomain(serviceName)
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
@@ -700,9 +1001,9 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
svc.Source = service.SourceEphemeral
if svc.Domain == "" {
domain, err := m.buildRandomDomain(svc.Name)
domain, err := m.resolveDefaultDomain(svc.Name)
if err != nil {
return nil, fmt.Errorf("build random domain for service %s: %w", svc.Name, err)
return nil, err
}
svc.Domain = domain
}
@@ -743,10 +1044,16 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
m.accountManager.UpdateAccountPeers(ctx, accountID)
serviceURL := "https://" + svc.Domain
if service.IsL4Protocol(svc.Mode) {
serviceURL = fmt.Sprintf("%s://%s:%d", svc.Mode, svc.Domain, svc.ListenPort)
}
return &service.ExposeServiceResponse{
ServiceName: svc.Name,
ServiceURL: "https://" + svc.Domain,
Domain: svc.Domain,
ServiceName: svc.Name,
ServiceURL: serviceURL,
Domain: svc.Domain,
PortAutoAssigned: svc.PortAutoAssigned,
}, nil
}
@@ -765,64 +1072,47 @@ func (m *Manager) getGroupIDsFromNames(ctx context.Context, accountID string, gr
return groupIDs, nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
func (m *Manager) getDefaultClusterDomain() (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
return "", fmt.Errorf("unable to get cluster domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
return "", fmt.Errorf("no cluster domains available")
}
index := rand.IntN(len(clusterDomains))
domain := name + "." + clusterDomains[index]
return domain, nil
return clusterDomains[rand.IntN(len(clusterDomains))], nil
}
func (m *Manager) buildRandomDomain(name string) (string, error) {
domain, err := m.getDefaultClusterDomain()
if err != nil {
return "", err
}
return name + "." + domain, nil
}
// RenewServiceFromPeer updates the DB timestamp for the peer's ephemeral service.
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, domain)
func (m *Manager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.store.RenewEphemeralService(ctx, accountID, peerID, serviceID)
}
// StopServiceFromPeer stops a peer's active expose session by deleting the service from the DB.
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
func (m *Manager) StopServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, serviceID, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service %s: %v", serviceID, err)
return err
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// deleteServiceFromPeer deletes a peer-initiated service identified by service ID.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
svc, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
func (m *Manager) deleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string, expired bool) error {
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, svc.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *Manager) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*service.Service, error) {
svc, err := m.store.GetServiceByDomain(ctx, domain)
if err != nil {
return nil, err
}
if svc.Source != service.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if svc.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return svc, nil
return m.deletePeerService(ctx, accountID, peerID, serviceID, activityCode)
}
func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/mock_server"
resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -803,8 +804,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -826,9 +827,9 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
Port: 80,
Mode: "http",
Domain: "example.com",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -847,8 +848,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
require.NoError(t, err)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -860,8 +861,8 @@ func TestCreateServiceFromPeer(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 0,
Protocol: "http",
Port: 0,
Mode: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
@@ -878,62 +879,52 @@ func TestExposeServiceRequestValidate(t *testing.T) {
}{
{
name: "valid http request",
req: rpservice.ExposeServiceRequest{Port: 8080, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 8080, Mode: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: rpservice.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
name: "https mode rejected",
req: rpservice.ExposeServiceRequest{Port: 443, Mode: "https", Pin: "123456"},
wantErr: "unsupported mode",
},
{
name: "port zero rejected",
req: rpservice.ExposeServiceRequest{Port: 0, Protocol: "http"},
req: rpservice.ExposeServiceRequest{Port: 0, Mode: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: rpservice.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: rpservice.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
name: "unsupported mode",
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "ftp"},
wantErr: "unsupported mode",
},
{
name: "invalid pin format",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: rpservice.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
req: rpservice.ExposeServiceRequest{Port: 80, Mode: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
@@ -966,14 +957,14 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
// First create a service
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, false)
require.NoError(t, err)
// Verify service is deleted
@@ -982,16 +973,17 @@ func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, svcID, true)
require.NoError(t, err)
})
}
@@ -1003,13 +995,14 @@ func TestStopServiceFromPeer(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
@@ -1022,8 +1015,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
@@ -1042,8 +1035,8 @@ func TestDeleteService_DeletesEphemeralExpose(t *testing.T) {
assert.Equal(t, int64(0), count, "ephemeral service should be deleted after API delete")
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
Port: 9090,
Mode: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete")
}
@@ -1054,8 +1047,8 @@ func TestDeleteAllServices_DeletesEphemeralExposes(t *testing.T) {
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
Port: uint16(8080 + i),
Mode: "http",
})
require.NoError(t, err)
}
@@ -1076,21 +1069,22 @@ func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent-service-id")
require.Error(t, err)
})
}
@@ -1191,3 +1185,156 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
require.NoError(t, err)
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string
oldP string
newP string
wantErr bool
}{
{"empty to http", "", "http", false},
{"http to http", "http", "http", false},
{"same protocol", "tcp", "tcp", false},
{"empty new proto", "tcp", "", false},
{"http to tcp", "http", "tcp", true},
{"tcp to udp", "tcp", "udp", true},
{"tls to http", "tls", "http", true},
{"udp to tls", "udp", "tls", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validateProtocolChange(tt.oldP, tt.newP)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot change mode")
} else {
require.NoError(t, err)
}
})
}
}
func TestValidateTargetReferences_ResourceTypeMismatch(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
tests := []struct {
name string
targetType rpservice.TargetType
resourceType resourcetypes.NetworkResourceType
wantErr bool
}{
{"host matches host", rpservice.TargetTypeHost, resourcetypes.Host, false},
{"domain matches domain", rpservice.TargetTypeDomain, resourcetypes.Domain, false},
{"subnet matches subnet", rpservice.TargetTypeSubnet, resourcetypes.Subnet, false},
{"host but resource is domain", rpservice.TargetTypeHost, resourcetypes.Domain, true},
{"domain but resource is host", rpservice.TargetTypeDomain, resourcetypes.Host, true},
{"host but resource is subnet", rpservice.TargetTypeHost, resourcetypes.Subnet, true},
{"subnet but resource is domain", rpservice.TargetTypeSubnet, resourcetypes.Domain, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockStore.EXPECT().
GetNetworkResourceByID(gomock.Any(), store.LockingStrengthShare, accountID, "resource-1").
Return(&resourcetypes.NetworkResource{Type: tt.resourceType}, nil)
targets := []*rpservice.Target{
{TargetId: "resource-1", TargetType: tt.targetType, Host: "10.0.0.1"},
}
err := validateTargetReferences(ctx, mockStore, accountID, targets)
if tt.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "target_type")
} else {
require.NoError(t, err)
}
})
}
}
func TestValidateTargetReferences_PeerValid(t *testing.T) {
ctx := context.Background()
ctrl := gomock.NewController(t)
mockStore := store.NewMockStore(ctrl)
accountID := "test-account"
mockStore.EXPECT().
GetPeerByID(gomock.Any(), store.LockingStrengthShare, accountID, "peer-1").
Return(&nbpeer.Peer{}, nil)
targets := []*rpservice.Target{
{TargetId: "peer-1", TargetType: rpservice.TargetTypePeer},
}
require.NoError(t, validateTargetReferences(ctx, mockStore, accountID, targets))
}
func TestValidateSubdomainRequirement(t *testing.T) {
ptrBool := func(b bool) *bool { return &b }
tests := []struct {
name string
domain string
cluster string
requireSubdomain *bool
wantErr bool
}{
{
name: "subdomain present, require_subdomain true",
domain: "app.eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: false,
},
{
name: "bare cluster domain, require_subdomain true",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: true,
},
{
name: "bare cluster domain, require_subdomain false",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(false),
wantErr: false,
},
{
name: "bare cluster domain, require_subdomain nil (default)",
domain: "eu1.proxy.netbird.io",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: nil,
wantErr: false,
},
{
name: "custom domain apex is not the cluster",
domain: "example.com",
cluster: "eu1.proxy.netbird.io",
requireSubdomain: ptrBool(true),
wantErr: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
mockCtrl := proxy.NewMockController(ctrl)
mockCtrl.EXPECT().ClusterRequireSubdomain(tc.cluster).Return(tc.requireSubdomain).AnyTimes()
mgr := &Manager{proxyController: mockCtrl}
err := mgr.validateSubdomainRequirement(tc.domain, tc.cluster)
if tc.wantErr {
require.Error(t, err)
assert.Contains(t, err.Error(), "requires a subdomain label")
} else {
require.NoError(t, err)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -44,7 +44,7 @@ func TestValidate_EmptyDomain(t *testing.T) {
func TestValidate_NoTargets(t *testing.T) {
rp := validProxy()
rp.Targets = nil
assert.ErrorContains(t, rp.Validate(), "at least one target")
assert.ErrorContains(t, rp.Validate(), "at least one target is required")
}
func TestValidate_EmptyTargetId(t *testing.T) {
@@ -120,9 +120,9 @@ func TestValidateTargetOptions_RequestTimeout(t *testing.T) {
}{
{"valid 30s", 30 * time.Second, ""},
{"valid 2m", 2 * time.Minute, ""},
{"valid 10m", 10 * time.Minute, ""},
{"zero is fine", 0, ""},
{"negative", -1 * time.Second, "must be positive"},
{"exceeds max", 10 * time.Minute, "exceeds maximum"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -273,7 +273,7 @@ func TestToProtoMapping_NoOptionsWhenDefault(t *testing.T) {
func TestIsDefaultPort(t *testing.T) {
tests := []struct {
scheme string
port int
port uint16
want bool
}{
{"http", 80, true},
@@ -299,7 +299,7 @@ func TestToProtoMapping_PortInTargetURL(t *testing.T) {
name string
protocol string
host string
port int
port uint16
wantTarget string
}{
{
@@ -645,8 +645,8 @@ func TestGenerateExposeName(t *testing.T) {
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &ExposeServiceRequest{
Port: 8080,
Protocol: "http",
Port: 8080,
Mode: "http",
}
service := req.ToService("account-1", "peer-1", "mysvc")
@@ -658,7 +658,7 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
require.Len(t, service.Targets, 1)
target := service.Targets[0]
assert.Equal(t, 8080, target.Port)
assert.Equal(t, uint16(8080), target.Port)
assert.Equal(t, "http", target.Protocol)
assert.Equal(t, "peer-1", target.TargetId)
assert.Equal(t, TargetTypePeer, target.TargetType)
@@ -730,3 +730,208 @@ func TestExposeServiceRequest_ToService(t *testing.T) {
require.NotNil(t, service.Auth.BearerAuth)
})
}
func TestValidate_TLSOnly(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_TLSMissingListenPort(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 0,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "listen_port is required")
}
func TestValidate_TLSMissingDomain(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "domain is required")
}
func TestValidate_TCPValid(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_TCPMissingListenPort(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
require.NoError(t, rp.Validate(), "TCP with listen_port=0 is valid (auto-assigned by manager)")
}
func TestValidate_L4MultipleTargets(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
{TargetId: "peer-2", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "exactly one target")
}
func TestValidate_L4TargetMissingPort(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 0, Enabled: true},
},
}
assert.ErrorContains(t, rp.Validate(), "port is required")
}
func TestValidate_TLSInvalidTargetType(t *testing.T) {
rp := &Service{
Name: "tls-svc",
Mode: "tls",
Domain: "example.com",
ListenPort: 443,
Targets: []*Target{
{TargetId: "peer-1", TargetType: "invalid", Protocol: "tcp", Port: 443, Enabled: true},
},
}
assert.Error(t, rp.Validate())
}
func TestValidate_TLSSubnetValid(t *testing.T) {
rp := &Service{
Name: "tls-subnet",
Mode: "tls",
Domain: "example.com",
ListenPort: 8443,
Targets: []*Target{
{TargetId: "subnet-1", TargetType: TargetTypeSubnet, Protocol: "tcp", Port: 443, Host: "10.0.0.5", Enabled: true},
},
}
require.NoError(t, rp.Validate())
}
func TestValidate_L4DomainTargetValid(t *testing.T) {
modes := []struct {
mode string
port uint16
proto string
}{
{"tcp", 5432, "tcp"},
{"tls", 443, "tcp"},
{"udp", 5432, "udp"},
}
for _, m := range modes {
t.Run(m.mode, func(t *testing.T) {
rp := &Service{
Name: m.mode + "-domain",
Mode: m.mode,
Domain: "cluster.test",
ListenPort: m.port,
Targets: []*Target{
{TargetId: "resource-1", TargetType: TargetTypeDomain, Protocol: m.proto, Port: m.port, Enabled: true},
},
}
require.NoError(t, rp.Validate())
})
}
}
func TestValidate_HTTPProxyProtocolRejected(t *testing.T) {
rp := validProxy()
rp.Targets[0].ProxyProtocol = true
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for HTTP")
}
func TestValidate_UDPProxyProtocolRejected(t *testing.T) {
rp := &Service{
Name: "udp-svc",
Mode: "udp",
Domain: "cluster.test",
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "udp", Port: 5432, Enabled: true, ProxyProtocol: true},
},
}
assert.ErrorContains(t, rp.Validate(), "proxy_protocol is not supported for UDP")
}
func TestValidate_TCPProxyProtocolAllowed(t *testing.T) {
rp := &Service{
Name: "tcp-svc",
Mode: "tcp",
Domain: "cluster.test",
ListenPort: 5432,
Targets: []*Target{
{TargetId: "peer-1", TargetType: TargetTypePeer, Protocol: "tcp", Port: 5432, Enabled: true, ProxyProtocol: true},
},
}
require.NoError(t, rp.Validate())
}
func TestExposeServiceRequest_Validate_L4RejectsAuth(t *testing.T) {
tests := []struct {
name string
req ExposeServiceRequest
}{
{
name: "tcp with pin",
req: ExposeServiceRequest{Port: 8080, Mode: "tcp", Pin: "123456"},
},
{
name: "udp with password",
req: ExposeServiceRequest{Port: 8080, Mode: "udp", Password: "secret"},
},
{
name: "tls with user groups",
req: ExposeServiceRequest{Port: 443, Mode: "tls", UserGroups: []string{"admins"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "authentication is not supported")
})
}
}
func TestExposeServiceRequest_Validate_HTTPAllowsAuth(t *testing.T) {
req := ExposeServiceRequest{Port: 8080, Mode: "http", Pin: "123456"}
require.NoError(t, req.Validate())
}