[management] Refactor expose feature: move business logic from gRPC to manager (#5435)

Consolidate all expose business logic (validation, permission checks, TTL tracking, reaping) into the manager layer, making the gRPC layer a pure transport adapter that only handles proto conversion and authentication.

- Add ExposeServiceRequest/ExposeServiceResponse domain types with validation in the reverseproxy package
- Move expose tracker (TTL tracking, reaping, per-peer limits) from gRPC server into manager/expose_tracker.go
- Internalize tracking in CreateServiceFromPeer, RenewServiceFromPeer, and new StopServiceFromPeer so callers don't manage tracker state
- Untrack ephemeral services in DeleteService/DeleteAllServices to keep tracker in sync when services are deleted via API
- Simplify gRPC expose handlers to parse, auth, convert, delegate
- Remove tracker methods from Manager interface (internal detail)
This commit is contained in:
Maycon Santos
2026-02-24 15:09:30 +01:00
committed by GitHub
parent f8c0321aee
commit 327142837c
17 changed files with 1072 additions and 659 deletions

View File

@@ -58,7 +58,7 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
}
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
ticker := time.NewTicker(10 * time.Second)
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
defer m.stop(domain)

View File

@@ -21,8 +21,8 @@ type Manager interface {
GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error)
GetAccountServices(ctx context.Context, accountID string) ([]*Service, error)
GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error)
ValidateExposePermission(ctx context.Context, accountID, peerID string) error
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *Service) (*Service, error)
DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error
CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error)
RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error
StartExposeReaper(ctx context.Context)
}

View File

@@ -49,6 +49,21 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// CreateServiceFromPeer mocks base method.
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req)
ret0, _ := ret[0].(*ExposeServiceResponse)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
@@ -63,21 +78,6 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// CreateServiceFromPeer mocks base method.
func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, service)
ret0, _ := ret[0].(*Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer.
func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, service interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, service)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()
@@ -92,48 +92,6 @@ func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, service
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID)
}
// DeleteServiceFromPeer mocks base method.
func (m *MockManager) DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteServiceFromPeer indicates an expected call of DeleteServiceFromPeer.
func (mr *MockManagerMockRecorder) DeleteServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceFromPeer", reflect.TypeOf((*MockManager)(nil).DeleteServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// ExpireServiceFromPeer mocks base method.
func (m *MockManager) ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExpireServiceFromPeer", ctx, accountID, peerID, serviceID)
ret0, _ := ret[0].(error)
return ret0
}
// ExpireServiceFromPeer indicates an expected call of ExpireServiceFromPeer.
func (mr *MockManagerMockRecorder) ExpireServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireServiceFromPeer", reflect.TypeOf((*MockManager)(nil).ExpireServiceFromPeer), ctx, accountID, peerID, serviceID)
}
// ValidateExposePermission mocks base method.
func (m *MockManager) ValidateExposePermission(ctx context.Context, accountID, peerID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateExposePermission", ctx, accountID, peerID)
ret0, _ := ret[0].(error)
return ret0
}
// ValidateExposePermission indicates an expected call of ValidateExposePermission.
func (mr *MockManagerMockRecorder) ValidateExposePermission(ctx, accountID, peerID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateExposePermission", reflect.TypeOf((*MockManager)(nil).ValidateExposePermission), ctx, accountID, peerID)
}
// GetAccountServices mocks base method.
func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) {
m.ctrl.T.Helper()
@@ -252,6 +210,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID)
}
// RenewServiceFromPeer mocks base method.
func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer.
func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain)
}
// SetCertificateIssuedAt mocks base method.
func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error {
m.ctrl.T.Helper()
@@ -280,6 +252,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status)
}
// StartExposeReaper mocks base method.
func (m *MockManager) StartExposeReaper(ctx context.Context) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "StartExposeReaper", ctx)
}
// StartExposeReaper indicates an expected call of StartExposeReaper.
func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx)
}
// StopServiceFromPeer mocks base method.
func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain)
ret0, _ := ret[0].(error)
return ret0
}
// StopServiceFromPeer indicates an expected call of StopServiceFromPeer.
func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain)
}
// UpdateService mocks base method.
func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) {
m.ctrl.T.Helper()

View File

@@ -0,0 +1,163 @@
package manager
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
)
type trackedExpose struct {
mu sync.Mutex
domain string
accountID string
peerID string
lastRenewed time.Time
expiring bool
}
type exposeTracker struct {
activeExposes sync.Map
exposeCreateMu sync.Mutex
manager *managerImpl
}
func exposeKey(peerID, domain string) string {
return peerID + ":" + domain
}
// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new
// active expose session under the same lock. Returns (true, false) if the expose
// was already tracked (duplicate), (false, true) if tracking succeeded, and
// (false, false) if the peer has reached the limit.
func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) {
t.exposeCreateMu.Lock()
defer t.exposeCreateMu.Unlock()
key := exposeKey(peerID, domain)
_, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{
domain: domain,
accountID: accountID,
peerID: peerID,
lastRenewed: time.Now(),
})
if loaded {
return true, false
}
if t.CountPeerExposes(peerID) > maxExposesPerPeer {
t.activeExposes.Delete(key)
return false, false
}
return false, true
}
// UntrackExpose removes an active expose session from tracking.
func (t *exposeTracker) UntrackExpose(peerID, domain string) {
t.activeExposes.Delete(exposeKey(peerID, domain))
}
// CountPeerExposes returns the number of active expose sessions for a peer.
func (t *exposeTracker) CountPeerExposes(peerID string) int {
count := 0
t.activeExposes.Range(func(_, val any) bool {
if expose := val.(*trackedExpose); expose.peerID == peerID {
count++
}
return true
})
return count
}
// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer.
func (t *exposeTracker) MaxExposesPerPeer() int {
return maxExposesPerPeer
}
// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose.
// Returns false if the expose is not tracked or is being reaped.
func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool {
key := exposeKey(peerID, domain)
val, ok := t.activeExposes.Load(key)
if !ok {
return false
}
expose := val.(*trackedExpose)
expose.mu.Lock()
if expose.expiring {
expose.mu.Unlock()
return false
}
expose.lastRenewed = time.Now()
expose.mu.Unlock()
return true
}
// StopTrackedExpose removes an active expose session from tracking.
// Returns false if the expose was not tracked.
func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool {
key := exposeKey(peerID, domain)
_, ok := t.activeExposes.LoadAndDelete(key)
return ok
}
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
go func() {
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
t.reapExpiredExposes()
}
}
}()
}
func (t *exposeTracker) reapExpiredExposes() {
t.activeExposes.Range(func(key, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expired := time.Since(expose.lastRenewed) > exposeTTL
if expired {
expose.expiring = true
}
expose.mu.Unlock()
if !expired {
return true
}
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true)
s, _ := status.FromError(err)
switch {
case err == nil:
t.activeExposes.Delete(key)
case s.ErrorType == status.NotFound:
log.Debugf("service %s was already deleted", expose.domain)
default:
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err)
}
return true
})
}

View File

@@ -0,0 +1,256 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
)
func TestExposeKey(t *testing.T) {
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
}
func TestTrackExposeIfAllowed(t *testing.T) {
t.Run("first track succeeds", func(t *testing.T) {
tracker := &exposeTracker{}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.False(t, alreadyTracked, "first track should not be duplicate")
assert.True(t, ok, "first track should be allowed")
})
t.Run("duplicate track detected", func(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.True(t, alreadyTracked, "second track should be duplicate")
assert.False(t, ok)
})
t.Run("rejects when at limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
assert.True(t, ok, "track %d should be allowed", i)
}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
assert.False(t, alreadyTracked)
assert.False(t, ok, "should reject when at limit")
})
t.Run("other peer unaffected by limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
}
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.True(t, ok, "other peer should still be within limit")
})
}
func TestUntrackExpose(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
tracker.UntrackExpose("peer1", "a.com")
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
}
func TestCountPeerExposes(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
}
func TestMaxExposesPerPeer(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
}
func TestRenewTrackedExpose(t *testing.T) {
tracker := &exposeTracker{}
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should not find untracked expose")
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
found = tracker.RenewTrackedExpose("peer1", "a.com")
assert.True(t, found, "should find tracked expose")
}
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
// Simulate reaper marking the expose as expiring
key := exposeKey("peer1", "a.com")
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.expiring = true
expose.mu.Unlock()
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should reject renewal when expiring")
}
func TestReapExpiredExposes(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the tracked entry
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Add an active (non-expired) tracking entry
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
domain: "active.com",
accountID: testAccountID,
peerID: "peer1",
lastRenewed: time.Now(),
})
tracker.reapExpiredExposes()
_, exists := tracker.activeExposes.Load(key)
assert.False(t, exists, "expired expose should be removed")
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
assert.True(t, exists, "active expose should remain")
}
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
// Expire it
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Renew should succeed before reaping
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
// Re-expire and reap
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
tracker.reapExpiredExposes()
// Entry is deleted, renew returns false
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
}
func TestConcurrentTrackAndCount(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Manually expire all tracked entries
tracker.activeExposes.Range(func(_, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
return true
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
tracker.reapExpiredExposes()
}()
go func() {
defer wg.Done()
tracker.CountPeerExposes(testPeerID)
}()
wg.Wait()
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
}
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
expose := &trackedExpose{
lastRenewed: time.Now().Add(-1 * time.Hour),
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
expose.lastRenewed = time.Now()
expose.mu.Unlock()
}
}()
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
_ = time.Since(expose.lastRenewed)
expose.mu.Unlock()
}
}()
wg.Wait()
expose.mu.Lock()
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
expose.mu.Unlock()
}

View File

@@ -40,11 +40,12 @@ type managerImpl struct {
settingsManager settings.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
clusterDeriver ClusterDeriver
exposeTracker *exposeTracker
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
mgr := &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
@@ -52,6 +53,13 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
proxyGRPCServer: proxyGRPCServer,
clusterDeriver: clusterDeriver,
}
mgr.exposeTracker = &exposeTracker{manager: mgr}
return mgr
}
// StartExposeReaper delegates to the expose tracker.
func (m *managerImpl) StartExposeReaper(ctx context.Context) {
m.exposeTracker.StartExposeReaper(ctx)
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
@@ -418,6 +426,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err
}
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
@@ -460,6 +472,9 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
for _, service := range services {
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
@@ -617,9 +632,9 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri
return target.ServiceID, nil
}
// ValidateExposePermission checks whether the peer is allowed to use the expose feature.
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *managerImpl) ValidateExposePermission(ctx context.Context, accountID, peerID string) error {
func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
@@ -650,8 +665,23 @@ func (m *managerImpl) ValidateExposePermission(ctx context.Context, accountID, p
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It skips user permission checks since authorization is done at the gRPC handler level.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
return nil, err
}
serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
service := req.ToService(accountID, peerID, serviceName)
service.Source = reverseproxy.SourceEphemeral
if service.Domain == "" {
@@ -665,7 +695,7 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", service.ID, err)
return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err)
}
service.Auth.BearerAuth.DistributionGroups = groupIDs
}
@@ -687,8 +717,21 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
return nil, err
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID)
if alreadyTracked {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err)
}
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if !allowed {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err)
}
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
@@ -696,10 +739,13 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return &reverseproxy.ExposeServiceResponse{
ServiceName: service.Name,
ServiceURL: "https://" + service.Domain,
Domain: service.Domain,
}, nil
}
func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
@@ -718,6 +764,9 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string
}
func (m *managerImpl) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
@@ -727,15 +776,60 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) {
return domain, nil
}
// DeleteServiceFromPeer deletes a peer-initiated service.
// It validates that the service was created by a peer to prevent deleting API-created services.
func (m *managerImpl) DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.deletePeerService(ctx, accountID, peerID, serviceID, activity.PeerServiceUnexposed)
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
// Returns an error if the expose is not actively tracked.
func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
}
// ExpireServiceFromPeer deletes a peer-initiated service that was not renewed within the TTL.
func (m *managerImpl) ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.deletePeerService(ctx, accountID, peerID, serviceID, activity.PeerServiceExposeExpired)
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
}
if !m.exposeTracker.StopTrackedExpose(peerID, domain) {
log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain)
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
service, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if service.Source != reverseproxy.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if service.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return service, nil
}
func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {

View File

@@ -658,6 +658,13 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
PeerExposeEnabled: true,
PeerExposeGroups: []string{testGroupID},
},
Users: map[string]*types.User{
testUserID: {
Id: testUserID,
AccountID: testAccountID,
Role: types.UserRoleAdmin,
},
},
Peers: map[string]*nbpeer.Peer{
testPeerID: {
ID: testPeerID,
@@ -712,16 +719,17 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
domains: []string{"test.netbird.io"},
},
}
mgr.exposeTracker = &exposeTracker{manager: mgr}
return mgr, testStore
}
func TestValidateExposePermission(t *testing.T) {
func Test_validateExposePermission(t *testing.T) {
ctx := context.Background()
t.Run("allowed when peer is in expose group", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
assert.NoError(t, err)
})
@@ -742,7 +750,7 @@ func TestValidateExposePermission(t *testing.T) {
})
require.NoError(t, err)
err = mgr.ValidateExposePermission(ctx, testAccountID, otherPeerID)
err = mgr.validateExposePermission(ctx, testAccountID, otherPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not in an allowed expose group")
})
@@ -757,7 +765,7 @@ func TestValidateExposePermission(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
err = mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err = mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not enabled")
})
@@ -772,7 +780,7 @@ func TestValidateExposePermission(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
err = mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err = mgr.validateExposePermission(ctx, testAccountID, testPeerID)
assert.Error(t, err)
})
@@ -781,7 +789,7 @@ func TestValidateExposePermission(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
mgr := &managerImpl{store: mockStore}
err := mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "get account settings")
})
@@ -793,82 +801,290 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with random domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
service := &reverseproxy.Service{
Name: "my-expose",
Enabled: true,
Targets: []*reverseproxy.Target{
{
AccountID: testAccountID,
Port: 8080,
Protocol: "http",
TargetId: testPeerID,
TargetType: reverseproxy.TargetTypePeer,
Enabled: true,
},
},
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
assert.NotEmpty(t, created.ID, "service should have an ID")
assert.Contains(t, created.Domain, "test.netbird.io", "domain should use cluster domain")
assert.Equal(t, reverseproxy.SourceEphemeral, created.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, created.SourcePeer, "source peer should be set")
assert.NotNil(t, created.Meta.LastRenewedAt, "last renewed should be set")
assert.NotEmpty(t, resp.ServiceName, "service name should be generated")
assert.Contains(t, resp.Domain, "test.netbird.io", "domain should use cluster domain")
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
// Verify service is persisted in store
persisted, err := testStore.GetServiceByID(ctx, store.LockingStrengthNone, testAccountID, created.ID)
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
assert.Equal(t, created.ID, persisted.ID)
assert.Equal(t, created.Domain, persisted.Domain)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
})
t.Run("creates service with custom domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
service := &reverseproxy.Service{
Name: "custom",
Domain: "custom.example.com",
Enabled: true,
Targets: []*reverseproxy.Target{
{
AccountID: testAccountID,
Port: 80,
Protocol: "http",
TargetId: testPeerID,
TargetType: reverseproxy.TargetTypePeer,
Enabled: true,
},
},
req := &reverseproxy.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
}
created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
assert.Equal(t, "custom.example.com", created.Domain, "should keep the provided domain")
assert.Contains(t, resp.Domain, "example.com", "should use the provided domain")
})
t.Run("replaces host by peer IP lookup", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
t.Run("validates expose permission internally", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
service := &reverseproxy.Service{
Name: "lookup-test",
Enabled: true,
Targets: []*reverseproxy.Target{
{
AccountID: testAccountID,
Port: 3000,
Protocol: "http",
TargetId: testPeerID,
TargetType: reverseproxy.TargetTypePeer,
Enabled: true,
},
},
// Disable peer expose
s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
s.PeerExposeEnabled = false
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "not enabled")
})
t.Run("validates request fields", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
Port: 0,
Protocol: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "port")
})
}
func TestExposeServiceRequestValidate(t *testing.T) {
tests := []struct {
name string
req reverseproxy.ExposeServiceRequest
wantErr string
}{
{
name: "valid http request",
req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
},
{
name: "port zero rejected",
req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
},
{
name: "invalid pin format",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
}
})
}
t.Run("nil receiver", func(t *testing.T) {
var req *reverseproxy.ExposeServiceRequest
err := req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "request cannot be nil")
})
}
func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
ctx := context.Background()
t.Run("deletes service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// First create a service
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
require.Len(t, created.Targets, 1)
assert.Equal(t, "100.64.0.1", created.Targets[0].Host, "host should be resolved to peer IP")
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
require.NoError(t, err)
// Verify service is deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
require.NoError(t, err)
})
}
func TestStopServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("stops service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
}
func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create")
// Look up the service by domain to get its store ID
svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
// Delete via the API path (user-initiated)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
require.NoError(t, err)
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
// A new expose should succeed (not blocked by stale tracking)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete cleared tracking")
}
func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked")
err := mgr.DeleteAllServices(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices")
}
func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
})
}

View File

@@ -318,63 +318,6 @@ func isDefaultPort(scheme string, port int) bool {
return (scheme == "https" && port == 443) || (scheme == "http" && port == 80)
}
// FromExposeRequest builds a Service from a peer expose gRPC request.
func FromExposeRequest(req *proto.ExposeServiceRequest, accountID, peerID, serviceName string) *Service {
service := &Service{
AccountID: accountID,
Name: serviceName,
Enabled: true,
Targets: []*Target{
{
AccountID: accountID,
Port: int(req.Port),
Protocol: exposeProtocolToString(req.Protocol),
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
},
}
if req.Domain != "" {
service.Domain = serviceName + "." + req.Domain
}
if req.Pin != "" {
service.Auth.PinAuth = &PINAuthConfig{
Enabled: true,
Pin: req.Pin,
}
}
if req.Password != "" {
service.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: true,
Password: req.Password,
}
}
if len(req.UserGroups) > 0 {
service.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: req.UserGroups,
}
}
return service
}
func exposeProtocolToString(p proto.ExposeProtocol) string {
switch p {
case proto.ExposeProtocol_EXPOSE_HTTP:
return "http"
case proto.ExposeProtocol_EXPOSE_HTTPS:
return "https"
default:
return "http"
}
}
func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) {
s.Name = req.Name
s.Domain = req.Domain
@@ -534,10 +477,107 @@ func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
return nil
}
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789"
var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`)
// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service.
type ExposeServiceRequest struct {
NamePrefix string
Port int
Protocol string
Domain string
Pin string
Password string
UserGroups []string
}
// Validate checks all fields of the expose request.
func (r *ExposeServiceRequest) Validate() error {
if r == nil {
return errors.New("request cannot be nil")
}
if r.Port < 1 || r.Port > 65535 {
return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port)
}
if r.Protocol != "http" && r.Protocol != "https" {
return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol)
}
if r.Pin != "" && !pinRegexp.MatchString(r.Pin) {
return errors.New("invalid pin: must be exactly 6 digits")
}
for _, g := range r.UserGroups {
if g == "" {
return errors.New("user group name cannot be empty")
}
}
if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) {
return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix)
}
return nil
}
// ToService builds a Service from the expose request.
func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service {
service := &Service{
AccountID: accountID,
Name: serviceName,
Enabled: true,
Targets: []*Target{
{
AccountID: accountID,
Port: r.Port,
Protocol: r.Protocol,
TargetId: peerID,
TargetType: TargetTypePeer,
Enabled: true,
},
},
}
if r.Domain != "" {
service.Domain = serviceName + "." + r.Domain
}
if r.Pin != "" {
service.Auth.PinAuth = &PINAuthConfig{
Enabled: true,
Pin: r.Pin,
}
}
if r.Password != "" {
service.Auth.PasswordAuth = &PasswordAuthConfig{
Enabled: true,
Password: r.Password,
}
}
if len(r.UserGroups) > 0 {
service.Auth.BearerAuth = &BearerAuthConfig{
Enabled: true,
DistributionGroups: r.UserGroups,
}
}
return service
}
// ExposeServiceResponse contains the result of a successful peer expose creation.
type ExposeServiceResponse struct {
ServiceName string
ServiceURL string
Domain string
}
// GenerateExposeName generates a random service name for peer-exposed services.
// The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens).
func GenerateExposeName(prefix string) (string, error) {

View File

@@ -458,14 +458,14 @@ func TestGenerateExposeName(t *testing.T) {
})
}
func TestFromExposeRequest(t *testing.T) {
func TestExposeServiceRequest_ToService(t *testing.T) {
t.Run("basic HTTP service", func(t *testing.T) {
req := &proto.ExposeServiceRequest{
req := &ExposeServiceRequest{
Port: 8080,
Protocol: proto.ExposeProtocol_EXPOSE_HTTP,
Protocol: "http",
}
service := FromExposeRequest(req, "account-1", "peer-1", "mysvc")
service := req.ToService("account-1", "peer-1", "mysvc")
assert.Equal(t, "account-1", service.AccountID)
assert.Equal(t, "mysvc", service.Name)
@@ -483,22 +483,22 @@ func TestFromExposeRequest(t *testing.T) {
})
t.Run("with custom domain", func(t *testing.T) {
req := &proto.ExposeServiceRequest{
req := &ExposeServiceRequest{
Port: 3000,
Domain: "example.com",
}
service := FromExposeRequest(req, "acc", "peer", "web")
service := req.ToService("acc", "peer", "web")
assert.Equal(t, "web.example.com", service.Domain)
})
t.Run("with PIN auth", func(t *testing.T) {
req := &proto.ExposeServiceRequest{
req := &ExposeServiceRequest{
Port: 80,
Pin: "1234",
}
service := FromExposeRequest(req, "acc", "peer", "svc")
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PinAuth)
assert.True(t, service.Auth.PinAuth.Enabled)
assert.Equal(t, "1234", service.Auth.PinAuth.Pin)
@@ -507,31 +507,31 @@ func TestFromExposeRequest(t *testing.T) {
})
t.Run("with password auth", func(t *testing.T) {
req := &proto.ExposeServiceRequest{
req := &ExposeServiceRequest{
Port: 80,
Password: "secret",
}
service := FromExposeRequest(req, "acc", "peer", "svc")
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.PasswordAuth)
assert.True(t, service.Auth.PasswordAuth.Enabled)
assert.Equal(t, "secret", service.Auth.PasswordAuth.Password)
})
t.Run("with user groups (bearer auth)", func(t *testing.T) {
req := &proto.ExposeServiceRequest{
req := &ExposeServiceRequest{
Port: 80,
UserGroups: []string{"admins", "devs"},
}
service := FromExposeRequest(req, "acc", "peer", "svc")
service := req.ToService("acc", "peer", "svc")
require.NotNil(t, service.Auth.BearerAuth)
assert.True(t, service.Auth.BearerAuth.Enabled)
assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups)
})
t.Run("with all auth types", func(t *testing.T) {
req := &proto.ExposeServiceRequest{
req := &ExposeServiceRequest{
Port: 443,
Domain: "myco.com",
Pin: "9999",
@@ -539,7 +539,7 @@ func TestFromExposeRequest(t *testing.T) {
UserGroups: []string{"ops"},
}
service := FromExposeRequest(req, "acc", "peer", "full")
service := req.ToService("acc", "peer", "full")
assert.Equal(t, "full.myco.com", service.Domain)
require.NotNil(t, service.Auth.PinAuth)
require.NotNil(t, service.Auth.PasswordAuth)

View File

@@ -152,8 +152,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
if err != nil {
log.Fatalf("failed to create management server: %v", err)
}
srv.SetReverseProxyManager(s.ReverseProxyManager())
srv.StartExposeReaper(context.Background())
reverseProxyMgr := s.ReverseProxyManager()
srv.SetReverseProxyManager(reverseProxyMgr)
if reverseProxyMgr != nil {
reverseProxyMgr.StartExposeReaper(context.Background())
}
mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv)
mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer())

View File

@@ -2,9 +2,6 @@ package grpc
import (
"context"
"regexp"
"sync"
"time"
pb "github.com/golang/protobuf/proto" // nolint
log "github.com/sirupsen/logrus"
@@ -21,27 +18,6 @@ import (
internalStatus "github.com/netbirdio/netbird/shared/management/status"
)
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
)
type activeExpose struct {
mu sync.Mutex
serviceID string
domain string
accountID string
peerID string
lastRenewed time.Time
}
func exposeKey(peerID, domain string) string {
return peerID + ":" + domain
}
// CreateExpose handles a peer request to create a new expose service.
func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
exposeReq := &proto.ExposeServiceRequest{}
@@ -58,72 +34,29 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage)
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
if exposeReq.Protocol != proto.ExposeProtocol_EXPOSE_HTTP && exposeReq.Protocol != proto.ExposeProtocol_EXPOSE_HTTPS {
return nil, status.Errorf(codes.InvalidArgument, "only HTTP or HTTPS protocol are supported")
}
if exposeReq.Pin != "" && !pinRegexp.MatchString(exposeReq.Pin) {
return nil, status.Errorf(codes.InvalidArgument, "invalid pin: must be exactly 6 digits")
}
for _, g := range exposeReq.UserGroups {
if g == "" {
return nil, status.Errorf(codes.InvalidArgument, "user group name cannot be empty")
}
}
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
if err := reverseProxyMgr.ValidateExposePermission(ctx, accountID, peer.ID); err != nil {
log.WithContext(ctx).Debugf("expose permission denied for peer %s: %v", peer.ID, err)
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
serviceName, err := reverseproxy.GenerateExposeName(exposeReq.NamePrefix)
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &reverseproxy.ExposeServiceRequest{
NamePrefix: exposeReq.NamePrefix,
Port: int(exposeReq.Port),
Protocol: exposeProtocolToString(exposeReq.Protocol),
Domain: exposeReq.Domain,
Pin: exposeReq.Pin,
Password: exposeReq.Password,
UserGroups: exposeReq.UserGroups,
})
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "generate service name: %v", err)
return nil, mapExposeError(ctx, err)
}
service := reverseproxy.FromExposeRequest(exposeReq, accountID, peer.ID, serviceName)
// Serialize the count check to prevent concurrent CreateExpose calls from
// exceeding maxExposesPerPeer. The lock is held only for the check; the
// actual service creation happens outside the lock.
s.exposeCreateMu.Lock()
if s.countPeerExposes(peer.ID) >= maxExposesPerPeer {
s.exposeCreateMu.Unlock()
return nil, status.Errorf(codes.ResourceExhausted, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
s.exposeCreateMu.Unlock()
created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, service)
if err != nil {
log.WithContext(ctx).Errorf("failed to create service from peer: %v", err)
return nil, status.Errorf(codes.Internal, "create service: %v", err)
}
key := exposeKey(peer.ID, created.Domain)
if _, loaded := s.activeExposes.LoadOrStore(key, &activeExpose{
serviceID: created.ID,
domain: created.Domain,
accountID: accountID,
peerID: peer.ID,
lastRenewed: time.Now(),
}); loaded {
s.deleteExposeService(ctx, accountID, peer.ID, created)
return nil, status.Errorf(codes.AlreadyExists, "peer already has an active expose session for this domain")
}
resp := &proto.ExposeServiceResponse{
ServiceName: created.Name,
ServiceUrl: "https://" + created.Domain,
return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{
ServiceName: created.ServiceName,
ServiceUrl: created.ServiceURL,
Domain: created.Domain,
}
return s.encryptResponse(peerKey, resp)
})
}
// RenewExpose extends the TTL of an active expose session.
@@ -134,21 +67,19 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) (
return nil, err
}
_, peer, err := s.authenticateExposePeer(ctx, peerKey)
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
key := exposeKey(peer.ID, renewReq.Domain)
val, ok := s.activeExposes.Load(key)
if !ok {
return nil, status.Errorf(codes.NotFound, "no active expose session for domain %s", renewReq.Domain)
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
expose := val.(*activeExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now()
expose.mu.Unlock()
if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.RenewExposeResponse{})
}
@@ -161,55 +92,45 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (*
return nil, err
}
_, peer, err := s.authenticateExposePeer(ctx, peerKey)
accountID, peer, err := s.authenticateExposePeer(ctx, peerKey)
if err != nil {
return nil, err
}
key := exposeKey(peer.ID, stopReq.Domain)
val, ok := s.activeExposes.LoadAndDelete(key)
if !ok {
return nil, status.Errorf(codes.NotFound, "no active expose session for domain %s", stopReq.Domain)
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return nil, status.Errorf(codes.Internal, "reverse proxy manager not available")
}
expose := val.(*activeExpose)
s.cleanupExpose(expose, false)
if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil {
return nil, mapExposeError(ctx, err)
}
return s.encryptResponse(peerKey, &proto.StopExposeResponse{})
}
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
func (s *Server) StartExposeReaper(ctx context.Context) {
go func() {
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
func mapExposeError(ctx context.Context, err error) error {
s, ok := internalStatus.FromError(err)
if !ok {
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.reapExpiredExposes()
}
}
}()
}
func (s *Server) reapExpiredExposes() {
s.activeExposes.Range(func(key, val any) bool {
expose := val.(*activeExpose)
expose.mu.Lock()
expired := time.Since(expose.lastRenewed) > exposeTTL
expose.mu.Unlock()
if expired {
if _, deleted := s.activeExposes.LoadAndDelete(key); deleted {
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
s.cleanupExpose(expose, true)
}
}
return true
})
switch s.Type() {
case internalStatus.InvalidArgument:
return status.Errorf(codes.InvalidArgument, "%s", s.Message)
case internalStatus.PermissionDenied:
return status.Errorf(codes.PermissionDenied, "%s", s.Message)
case internalStatus.NotFound:
return status.Errorf(codes.NotFound, "%s", s.Message)
case internalStatus.AlreadyExists:
return status.Errorf(codes.AlreadyExists, "%s", s.Message)
case internalStatus.PreconditionFailed:
return status.Errorf(codes.ResourceExhausted, "%s", s.Message)
default:
log.WithContext(ctx).Errorf("expose service error: %v", err)
return status.Errorf(codes.Internal, "internal error")
}
}
func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) {
@@ -246,47 +167,6 @@ func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key
return accountID, peer, nil
}
func (s *Server) deleteExposeService(ctx context.Context, accountID, peerID string, service *reverseproxy.Service) {
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
return
}
if err := reverseProxyMgr.DeleteServiceFromPeer(ctx, accountID, peerID, service.ID); err != nil {
log.WithContext(ctx).Debugf("failed to delete expose service %s: %v", service.ID, err)
}
}
func (s *Server) cleanupExpose(expose *activeExpose, expired bool) {
bgCtx := context.Background()
reverseProxyMgr := s.getReverseProxyManager()
if reverseProxyMgr == nil {
log.Errorf("cannot cleanup exposed service %s: reverse proxy manager not available", expose.serviceID)
return
}
var err error
if expired {
err = reverseProxyMgr.ExpireServiceFromPeer(bgCtx, expose.accountID, expose.peerID, expose.serviceID)
} else {
err = reverseProxyMgr.DeleteServiceFromPeer(bgCtx, expose.accountID, expose.peerID, expose.serviceID)
}
if err != nil {
log.Errorf("failed to delete peer-exposed service %s: %v", expose.serviceID, err)
}
}
func (s *Server) countPeerExposes(peerID string) int {
count := 0
s.activeExposes.Range(func(_, val any) bool {
if expose := val.(*activeExpose); expose.peerID == peerID {
count++
}
return true
})
return count
}
func (s *Server) getReverseProxyManager() reverseproxy.Manager {
s.reverseProxyMu.RLock()
defer s.reverseProxyMu.RUnlock()
@@ -299,3 +179,14 @@ func (s *Server) SetReverseProxyManager(mgr reverseproxy.Manager) {
defer s.reverseProxyMu.Unlock()
s.reverseProxyManager = mgr
}
func exposeProtocolToString(p proto.ExposeProtocol) string {
switch p {
case proto.ExposeProtocol_EXPOSE_HTTP:
return "http"
case proto.ExposeProtocol_EXPOSE_HTTPS:
return "https"
default:
return "http"
}
}

View File

@@ -1,242 +0,0 @@
package grpc
import (
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
)
func TestPinValidation(t *testing.T) {
tests := []struct {
pin string
valid bool
}{
{"123456", true},
{"000000", true},
{"12345", false},
{"1234567", false},
{"abcdef", false},
{"12345a", false},
{"", false},
{"12 345", false},
}
for _, tt := range tests {
assert.Equal(t, tt.valid, pinRegexp.MatchString(tt.pin), "pin %q", tt.pin)
}
}
func TestExposeKey(t *testing.T) {
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
}
func TestCountPeerExposes(t *testing.T) {
s := &Server{}
// No exposes
assert.Equal(t, 0, s.countPeerExposes("peer1"))
// Add some exposes for different peers
s.activeExposes.Store("peer1:a.com", &activeExpose{peerID: "peer1"})
s.activeExposes.Store("peer1:b.com", &activeExpose{peerID: "peer1"})
s.activeExposes.Store("peer2:a.com", &activeExpose{peerID: "peer2"})
assert.Equal(t, 2, s.countPeerExposes("peer1"), "peer1 should have 2 exposes")
assert.Equal(t, 1, s.countPeerExposes("peer2"), "peer2 should have 1 expose")
assert.Equal(t, 0, s.countPeerExposes("peer3"), "peer3 should have 0 exposes")
}
func TestReapExpiredExposes(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMgr := reverseproxy.NewMockManager(ctrl)
s := &Server{}
s.SetReverseProxyManager(mockMgr)
now := time.Now()
// Add an expired expose and a still-active one
s.activeExposes.Store("peer1:expired.com", &activeExpose{
serviceID: "svc-expired",
domain: "expired.com",
accountID: "acct1",
peerID: "peer1",
lastRenewed: now.Add(-2 * exposeTTL),
})
s.activeExposes.Store("peer1:active.com", &activeExpose{
serviceID: "svc-active",
domain: "active.com",
accountID: "acct1",
peerID: "peer1",
lastRenewed: now,
})
// Expect ExpireServiceFromPeer called only for the expired one
mockMgr.EXPECT().
ExpireServiceFromPeer(gomock.Any(), "acct1", "peer1", "svc-expired").
Return(nil)
s.reapExpiredExposes()
// Verify expired one is removed
_, exists := s.activeExposes.Load("peer1:expired.com")
assert.False(t, exists, "expired expose should be removed")
// Verify active one remains
_, exists = s.activeExposes.Load("peer1:active.com")
assert.True(t, exists, "active expose should remain")
}
func TestCleanupExpose_Delete(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMgr := reverseproxy.NewMockManager(ctrl)
s := &Server{}
s.SetReverseProxyManager(mockMgr)
mockMgr.EXPECT().
DeleteServiceFromPeer(gomock.Any(), "acct1", "peer1", "svc1").
Return(nil)
s.cleanupExpose(&activeExpose{
serviceID: "svc1",
accountID: "acct1",
peerID: "peer1",
}, false)
}
func TestCleanupExpose_Expire(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMgr := reverseproxy.NewMockManager(ctrl)
s := &Server{}
s.SetReverseProxyManager(mockMgr)
mockMgr.EXPECT().
ExpireServiceFromPeer(gomock.Any(), "acct1", "peer1", "svc1").
Return(nil)
s.cleanupExpose(&activeExpose{
serviceID: "svc1",
accountID: "acct1",
peerID: "peer1",
}, true)
}
func TestCleanupExpose_NilManager(t *testing.T) {
s := &Server{}
// Should not panic when reverse proxy manager is nil
s.cleanupExpose(&activeExpose{
serviceID: "svc1",
accountID: "acct1",
peerID: "peer1",
}, false)
}
func TestSetReverseProxyManager(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
s := &Server{}
// Initially nil
assert.Nil(t, s.getReverseProxyManager())
mockMgr := reverseproxy.NewMockManager(ctrl)
s.SetReverseProxyManager(mockMgr)
assert.NotNil(t, s.getReverseProxyManager())
// Can set to nil
s.SetReverseProxyManager(nil)
assert.Nil(t, s.getReverseProxyManager())
}
func TestReapExpiredExposes_ConcurrentSafety(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
mockMgr := reverseproxy.NewMockManager(ctrl)
mockMgr.EXPECT().
ExpireServiceFromPeer(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
s := &Server{}
s.SetReverseProxyManager(mockMgr)
// Pre-populate with expired sessions
for i := range 20 {
peerID := "peer1"
domain := "domain-" + string(rune('a'+i))
s.activeExposes.Store(exposeKey(peerID, domain), &activeExpose{
serviceID: "svc-" + domain,
domain: domain,
accountID: "acct1",
peerID: peerID,
lastRenewed: time.Now().Add(-2 * exposeTTL),
})
}
// Run reaper concurrently with count
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
s.reapExpiredExposes()
}()
go func() {
defer wg.Done()
s.countPeerExposes("peer1")
}()
wg.Wait()
assert.Equal(t, 0, s.countPeerExposes("peer1"), "all expired exposes should be reaped")
}
func TestActiveExposeMutexProtectsLastRenewed(t *testing.T) {
expose := &activeExpose{
lastRenewed: time.Now().Add(-1 * time.Hour),
}
// Simulate concurrent renew and read
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
expose.lastRenewed = time.Now()
expose.mu.Unlock()
}
}()
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
_ = time.Since(expose.lastRenewed)
expose.mu.Unlock()
}
}()
wg.Wait()
expose.mu.Lock()
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
expose.mu.Unlock()
}

View File

@@ -76,21 +76,19 @@ func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _
return "", nil
}
func (m *mockReverseProxyManager) ValidateExposePermission(_ context.Context, _, _ string) error {
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
return &reverseproxy.ExposeServiceResponse{}, nil
}
func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *mockReverseProxyManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {}
type mockUsersManager struct {
users map[string]*types.User

View File

@@ -82,8 +82,6 @@ type Server struct {
syncLimEnabled bool
syncLim int32
activeExposes sync.Map
exposeCreateMu sync.Mutex
reverseProxyManager reverseproxy.Manager
reverseProxyMu sync.RWMutex
}

View File

@@ -196,7 +196,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) {
require.NoError(t, err)
assert.False(t, resp.Valid, "Unknown proxy should be denied")
assert.Equal(t, "proxy_not_found", resp.DeniedReason)
assert.Equal(t, "service_not_found", resp.DeniedReason)
}
func TestValidateSession_InvalidToken(t *testing.T) {
@@ -263,6 +263,10 @@ func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _,
return nil
}
func (m *testValidateSessionProxyManager) DeleteAllServices(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error {
return nil
}
@@ -295,22 +299,20 @@ func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Conte
return "", nil
}
func (m *testValidateSessionProxyManager) ValidateExposePermission(_ context.Context, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testValidateSessionProxyManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *testValidateSessionProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testValidateSessionProxyManager) StartExposeReaper(_ context.Context) {}
type testValidateSessionUsersManager struct {
store store.Store
}

View File

@@ -413,22 +413,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri
return "", nil
}
func (m *testServiceManager) ValidateExposePermission(_ context.Context, _, _ string) error {
return nil
}
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
return nil, nil
}
func (m *testServiceManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *testServiceManager) StartExposeReaper(_ context.Context) {}
func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string {
t.Helper()

View File

@@ -247,21 +247,19 @@ func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context,
return "", nil
}
func (m *storeBackedServiceManager) ValidateExposePermission(_ context.Context, _, _ string) error {
func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
return &reverseproxy.ExposeServiceResponse{}, nil
}
func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) {
return &reverseproxy.Service{}, nil
}
func (m *storeBackedServiceManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error {
func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *storeBackedServiceManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error {
return nil
}
func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {}
func strPtr(s string) *string {
return &s