diff --git a/client/internal/expose/manager.go b/client/internal/expose/manager.go index ba6aa6dc9..8cd93685e 100644 --- a/client/internal/expose/manager.go +++ b/client/internal/expose/manager.go @@ -58,7 +58,7 @@ func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) { } func (m *Manager) KeepAlive(ctx context.Context, domain string) error { - ticker := time.NewTicker(10 * time.Second) + ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() defer m.stop(domain) diff --git a/management/internals/modules/reverseproxy/interface.go b/management/internals/modules/reverseproxy/interface.go index 95402bdf7..e7a21a24c 100644 --- a/management/internals/modules/reverseproxy/interface.go +++ b/management/internals/modules/reverseproxy/interface.go @@ -21,8 +21,8 @@ type Manager interface { GetServiceByID(ctx context.Context, accountID, serviceID string) (*Service, error) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) GetServiceIDByTargetID(ctx context.Context, accountID string, resourceID string) (string, error) - ValidateExposePermission(ctx context.Context, accountID, peerID string) error - CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *Service) (*Service, error) - DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error - ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error + CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) + RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error + StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error + StartExposeReaper(ctx context.Context) } diff --git a/management/internals/modules/reverseproxy/interface_mock.go b/management/internals/modules/reverseproxy/interface_mock.go index 19a4ecfe5..893025195 100644 --- a/management/internals/modules/reverseproxy/interface_mock.go +++ b/management/internals/modules/reverseproxy/interface_mock.go @@ -49,6 +49,21 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service) } +// CreateServiceFromPeer mocks base method. +func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *ExposeServiceRequest) (*ExposeServiceResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, req) + ret0, _ := ret[0].(*ExposeServiceResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer. +func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, req interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, req) +} + // DeleteAllServices mocks base method. func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error { m.ctrl.T.Helper() @@ -63,21 +78,6 @@ func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID) } -// CreateServiceFromPeer mocks base method. -func (m *MockManager) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *Service) (*Service, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateServiceFromPeer", ctx, accountID, peerID, service) - ret0, _ := ret[0].(*Service) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateServiceFromPeer indicates an expected call of CreateServiceFromPeer. -func (mr *MockManagerMockRecorder) CreateServiceFromPeer(ctx, accountID, peerID, service interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateServiceFromPeer", reflect.TypeOf((*MockManager)(nil).CreateServiceFromPeer), ctx, accountID, peerID, service) -} - // DeleteService mocks base method. func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error { m.ctrl.T.Helper() @@ -92,48 +92,6 @@ func (mr *MockManagerMockRecorder) DeleteService(ctx, accountID, userID, service return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteService", reflect.TypeOf((*MockManager)(nil).DeleteService), ctx, accountID, userID, serviceID) } -// DeleteServiceFromPeer mocks base method. -func (m *MockManager) DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteServiceFromPeer", ctx, accountID, peerID, serviceID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteServiceFromPeer indicates an expected call of DeleteServiceFromPeer. -func (mr *MockManagerMockRecorder) DeleteServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteServiceFromPeer", reflect.TypeOf((*MockManager)(nil).DeleteServiceFromPeer), ctx, accountID, peerID, serviceID) -} - -// ExpireServiceFromPeer mocks base method. -func (m *MockManager) ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExpireServiceFromPeer", ctx, accountID, peerID, serviceID) - ret0, _ := ret[0].(error) - return ret0 -} - -// ExpireServiceFromPeer indicates an expected call of ExpireServiceFromPeer. -func (mr *MockManagerMockRecorder) ExpireServiceFromPeer(ctx, accountID, peerID, serviceID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpireServiceFromPeer", reflect.TypeOf((*MockManager)(nil).ExpireServiceFromPeer), ctx, accountID, peerID, serviceID) -} - -// ValidateExposePermission mocks base method. -func (m *MockManager) ValidateExposePermission(ctx context.Context, accountID, peerID string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ValidateExposePermission", ctx, accountID, peerID) - ret0, _ := ret[0].(error) - return ret0 -} - -// ValidateExposePermission indicates an expected call of ValidateExposePermission. -func (mr *MockManagerMockRecorder) ValidateExposePermission(ctx, accountID, peerID interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateExposePermission", reflect.TypeOf((*MockManager)(nil).ValidateExposePermission), ctx, accountID, peerID) -} - // GetAccountServices mocks base method. func (m *MockManager) GetAccountServices(ctx context.Context, accountID string) ([]*Service, error) { m.ctrl.T.Helper() @@ -252,6 +210,20 @@ func (mr *MockManagerMockRecorder) ReloadService(ctx, accountID, serviceID inter return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReloadService", reflect.TypeOf((*MockManager)(nil).ReloadService), ctx, accountID, serviceID) } +// RenewServiceFromPeer mocks base method. +func (m *MockManager) RenewServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RenewServiceFromPeer", ctx, accountID, peerID, domain) + ret0, _ := ret[0].(error) + return ret0 +} + +// RenewServiceFromPeer indicates an expected call of RenewServiceFromPeer. +func (mr *MockManagerMockRecorder) RenewServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RenewServiceFromPeer", reflect.TypeOf((*MockManager)(nil).RenewServiceFromPeer), ctx, accountID, peerID, domain) +} + // SetCertificateIssuedAt mocks base method. func (m *MockManager) SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error { m.ctrl.T.Helper() @@ -280,6 +252,32 @@ func (mr *MockManagerMockRecorder) SetStatus(ctx, accountID, serviceID, status i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetStatus", reflect.TypeOf((*MockManager)(nil).SetStatus), ctx, accountID, serviceID, status) } +// StartExposeReaper mocks base method. +func (m *MockManager) StartExposeReaper(ctx context.Context) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartExposeReaper", ctx) +} + +// StartExposeReaper indicates an expected call of StartExposeReaper. +func (mr *MockManagerMockRecorder) StartExposeReaper(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartExposeReaper", reflect.TypeOf((*MockManager)(nil).StartExposeReaper), ctx) +} + +// StopServiceFromPeer mocks base method. +func (m *MockManager) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StopServiceFromPeer", ctx, accountID, peerID, domain) + ret0, _ := ret[0].(error) + return ret0 +} + +// StopServiceFromPeer indicates an expected call of StopServiceFromPeer. +func (mr *MockManagerMockRecorder) StopServiceFromPeer(ctx, accountID, peerID, domain interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopServiceFromPeer", reflect.TypeOf((*MockManager)(nil).StopServiceFromPeer), ctx, accountID, peerID, domain) +} + // UpdateService mocks base method. func (m *MockManager) UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error) { m.ctrl.T.Helper() diff --git a/management/internals/modules/reverseproxy/manager/expose_tracker.go b/management/internals/modules/reverseproxy/manager/expose_tracker.go new file mode 100644 index 000000000..ef285e923 --- /dev/null +++ b/management/internals/modules/reverseproxy/manager/expose_tracker.go @@ -0,0 +1,163 @@ +package manager + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/shared/management/status" + log "github.com/sirupsen/logrus" +) + +const ( + exposeTTL = 90 * time.Second + exposeReapInterval = 30 * time.Second + maxExposesPerPeer = 10 +) + +type trackedExpose struct { + mu sync.Mutex + domain string + accountID string + peerID string + lastRenewed time.Time + expiring bool +} + +type exposeTracker struct { + activeExposes sync.Map + exposeCreateMu sync.Mutex + manager *managerImpl +} + +func exposeKey(peerID, domain string) string { + return peerID + ":" + domain +} + +// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new +// active expose session under the same lock. Returns (true, false) if the expose +// was already tracked (duplicate), (false, true) if tracking succeeded, and +// (false, false) if the peer has reached the limit. +func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) { + t.exposeCreateMu.Lock() + defer t.exposeCreateMu.Unlock() + + key := exposeKey(peerID, domain) + _, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{ + domain: domain, + accountID: accountID, + peerID: peerID, + lastRenewed: time.Now(), + }) + if loaded { + return true, false + } + + if t.CountPeerExposes(peerID) > maxExposesPerPeer { + t.activeExposes.Delete(key) + return false, false + } + + return false, true +} + +// UntrackExpose removes an active expose session from tracking. +func (t *exposeTracker) UntrackExpose(peerID, domain string) { + t.activeExposes.Delete(exposeKey(peerID, domain)) +} + +// CountPeerExposes returns the number of active expose sessions for a peer. +func (t *exposeTracker) CountPeerExposes(peerID string) int { + count := 0 + t.activeExposes.Range(func(_, val any) bool { + if expose := val.(*trackedExpose); expose.peerID == peerID { + count++ + } + return true + }) + return count +} + +// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer. +func (t *exposeTracker) MaxExposesPerPeer() int { + return maxExposesPerPeer +} + +// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose. +// Returns false if the expose is not tracked or is being reaped. +func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool { + key := exposeKey(peerID, domain) + val, ok := t.activeExposes.Load(key) + if !ok { + return false + } + + expose := val.(*trackedExpose) + expose.mu.Lock() + if expose.expiring { + expose.mu.Unlock() + return false + } + expose.lastRenewed = time.Now() + expose.mu.Unlock() + + return true +} + +// StopTrackedExpose removes an active expose session from tracking. +// Returns false if the expose was not tracked. +func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool { + key := exposeKey(peerID, domain) + _, ok := t.activeExposes.LoadAndDelete(key) + return ok +} + +// StartExposeReaper starts a background goroutine that reaps expired expose sessions. +func (t *exposeTracker) StartExposeReaper(ctx context.Context) { + go func() { + ticker := time.NewTicker(exposeReapInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + t.reapExpiredExposes() + } + } + }() +} + +func (t *exposeTracker) reapExpiredExposes() { + t.activeExposes.Range(func(key, val any) bool { + expose := val.(*trackedExpose) + expose.mu.Lock() + expired := time.Since(expose.lastRenewed) > exposeTTL + if expired { + expose.expiring = true + } + expose.mu.Unlock() + + if !expired { + return true + } + + log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain) + + err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true) + + s, _ := status.FromError(err) + + switch { + case err == nil: + t.activeExposes.Delete(key) + case s.ErrorType == status.NotFound: + log.Debugf("service %s was already deleted", expose.domain) + default: + log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err) + } + + return true + }) +} diff --git a/management/internals/modules/reverseproxy/manager/expose_tracker_test.go b/management/internals/modules/reverseproxy/manager/expose_tracker_test.go new file mode 100644 index 000000000..2dc726590 --- /dev/null +++ b/management/internals/modules/reverseproxy/manager/expose_tracker_test.go @@ -0,0 +1,256 @@ +package manager + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" +) + +func TestExposeKey(t *testing.T) { + assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com")) + assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com")) + assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com")) +} + +func TestTrackExposeIfAllowed(t *testing.T) { + t.Run("first track succeeds", func(t *testing.T) { + tracker := &exposeTracker{} + alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + assert.False(t, alreadyTracked, "first track should not be duplicate") + assert.True(t, ok, "first track should be allowed") + }) + + t.Run("duplicate track detected", func(t *testing.T) { + tracker := &exposeTracker{} + tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + + alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + assert.True(t, alreadyTracked, "second track should be duplicate") + assert.False(t, ok) + }) + + t.Run("rejects when at limit", func(t *testing.T) { + tracker := &exposeTracker{} + for i := range maxExposesPerPeer { + _, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1") + assert.True(t, ok, "track %d should be allowed", i) + } + + alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1") + assert.False(t, alreadyTracked) + assert.False(t, ok, "should reject when at limit") + }) + + t.Run("other peer unaffected by limit", func(t *testing.T) { + tracker := &exposeTracker{} + for i := range maxExposesPerPeer { + tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1") + } + + _, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1") + assert.True(t, ok, "other peer should still be within limit") + }) +} + +func TestUntrackExpose(t *testing.T) { + tracker := &exposeTracker{} + + tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + assert.Equal(t, 1, tracker.CountPeerExposes("peer1")) + + tracker.UntrackExpose("peer1", "a.com") + assert.Equal(t, 0, tracker.CountPeerExposes("peer1")) +} + +func TestCountPeerExposes(t *testing.T) { + tracker := &exposeTracker{} + + assert.Equal(t, 0, tracker.CountPeerExposes("peer1")) + + tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1") + tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1") + + assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes") + assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose") + assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes") +} + +func TestMaxExposesPerPeer(t *testing.T) { + tracker := &exposeTracker{} + assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer()) +} + +func TestRenewTrackedExpose(t *testing.T) { + tracker := &exposeTracker{} + + found := tracker.RenewTrackedExpose("peer1", "a.com") + assert.False(t, found, "should not find untracked expose") + + tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + + found = tracker.RenewTrackedExpose("peer1", "a.com") + assert.True(t, found, "should find tracked expose") +} + +func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) { + tracker := &exposeTracker{} + tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1") + + // Simulate reaper marking the expose as expiring + key := exposeKey("peer1", "a.com") + val, _ := tracker.activeExposes.Load(key) + expose := val.(*trackedExpose) + expose.mu.Lock() + expose.expiring = true + expose.mu.Unlock() + + found := tracker.RenewTrackedExpose("peer1", "a.com") + assert.False(t, found, "should reject renewal when expiring") +} + +func TestReapExpiredExposes(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + tracker := mgr.exposeTracker + + ctx := context.Background() + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + }) + require.NoError(t, err) + + // Manually expire the tracked entry + key := exposeKey(testPeerID, resp.Domain) + val, _ := tracker.activeExposes.Load(key) + expose := val.(*trackedExpose) + expose.mu.Lock() + expose.lastRenewed = time.Now().Add(-2 * exposeTTL) + expose.mu.Unlock() + + // Add an active (non-expired) tracking entry + tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{ + domain: "active.com", + accountID: testAccountID, + peerID: "peer1", + lastRenewed: time.Now(), + }) + + tracker.reapExpiredExposes() + + _, exists := tracker.activeExposes.Load(key) + assert.False(t, exists, "expired expose should be removed") + + _, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com")) + assert.True(t, exists, "active expose should remain") +} + +func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + tracker := mgr.exposeTracker + + ctx := context.Background() + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + }) + require.NoError(t, err) + + key := exposeKey(testPeerID, resp.Domain) + val, _ := tracker.activeExposes.Load(key) + expose := val.(*trackedExpose) + + // Expire it + expose.mu.Lock() + expose.lastRenewed = time.Now().Add(-2 * exposeTTL) + expose.mu.Unlock() + + // Renew should succeed before reaping + assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs") + + // Re-expire and reap + expose.mu.Lock() + expose.lastRenewed = time.Now().Add(-2 * exposeTTL) + expose.mu.Unlock() + + tracker.reapExpiredExposes() + + // Entry is deleted, renew returns false + assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap") +} + +func TestConcurrentTrackAndCount(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + tracker := mgr.exposeTracker + ctx := context.Background() + + for i := range 5 { + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 8080 + i, + Protocol: "http", + }) + require.NoError(t, err) + } + + // Manually expire all tracked entries + tracker.activeExposes.Range(func(_, val any) bool { + expose := val.(*trackedExpose) + expose.mu.Lock() + expose.lastRenewed = time.Now().Add(-2 * exposeTTL) + expose.mu.Unlock() + return true + }) + + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + tracker.reapExpiredExposes() + }() + go func() { + defer wg.Done() + tracker.CountPeerExposes(testPeerID) + }() + wg.Wait() + + assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped") +} + +func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) { + expose := &trackedExpose{ + lastRenewed: time.Now().Add(-1 * time.Hour), + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for range 100 { + expose.mu.Lock() + expose.lastRenewed = time.Now() + expose.mu.Unlock() + } + }() + + go func() { + defer wg.Done() + for range 100 { + expose.mu.Lock() + _ = time.Since(expose.lastRenewed) + expose.mu.Unlock() + } + }() + + wg.Wait() + + expose.mu.Lock() + require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access") + expose.mu.Unlock() +} diff --git a/management/internals/modules/reverseproxy/manager/manager.go b/management/internals/modules/reverseproxy/manager/manager.go index ac839b8ea..b2c67e0c1 100644 --- a/management/internals/modules/reverseproxy/manager/manager.go +++ b/management/internals/modules/reverseproxy/manager/manager.go @@ -40,11 +40,12 @@ type managerImpl struct { settingsManager settings.Manager proxyGRPCServer *nbgrpc.ProxyServiceServer clusterDeriver ClusterDeriver + exposeTracker *exposeTracker } // NewManager creates a new service manager. func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager { - return &managerImpl{ + mgr := &managerImpl{ store: store, accountManager: accountManager, permissionsManager: permissionsManager, @@ -52,6 +53,13 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa proxyGRPCServer: proxyGRPCServer, clusterDeriver: clusterDeriver, } + mgr.exposeTracker = &exposeTracker{manager: mgr} + return mgr +} + +// StartExposeReaper delegates to the expose tracker. +func (m *managerImpl) StartExposeReaper(ctx context.Context) { + m.exposeTracker.StartExposeReaper(ctx) } func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) { @@ -418,6 +426,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv return err } + if service.Source == reverseproxy.SourceEphemeral { + m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain) + } + m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta()) m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "") @@ -460,6 +472,9 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig() for _, service := range services { + if service.Source == reverseproxy.SourceEphemeral { + m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain) + } m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta()) mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg) clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping) @@ -617,9 +632,9 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri return target.ServiceID, nil } -// ValidateExposePermission checks whether the peer is allowed to use the expose feature. +// validateExposePermission checks whether the peer is allowed to use the expose feature. // It verifies the account has peer expose enabled and that the peer belongs to an allowed group. -func (m *managerImpl) ValidateExposePermission(ctx context.Context, accountID, peerID string) error { +func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error { settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) @@ -650,8 +665,23 @@ func (m *managerImpl) ValidateExposePermission(ctx context.Context, accountID, p } // CreateServiceFromPeer creates a service initiated by a peer expose request. -// It skips user permission checks since authorization is done at the gRPC handler level. -func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *reverseproxy.Service) (*reverseproxy.Service, error) { +// It validates the request, checks expose permissions, enforces the per-peer limit, +// creates the service, and tracks it for TTL-based reaping. +func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { + if err := req.Validate(); err != nil { + return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err) + } + + if err := m.validateExposePermission(ctx, accountID, peerID); err != nil { + return nil, err + } + + serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix) + if err != nil { + return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err) + } + + service := req.ToService(accountID, peerID, serviceName) service.Source = reverseproxy.SourceEphemeral if service.Domain == "" { @@ -665,7 +695,7 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled { groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups) if err != nil { - return nil, fmt.Errorf("get group ids for service %s: %w", service.ID, err) + return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err) } service.Auth.BearerAuth.DistributionGroups = groupIDs } @@ -687,8 +717,21 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer return nil, err } - meta := addPeerInfoToEventMeta(service.EventMeta(), peer) + alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID) + if alreadyTracked { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil { + log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err) + } + return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain") + } + if !allowed { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil { + log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err) + } + return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) + } + meta := addPeerInfoToEventMeta(service.EventMeta(), peer) m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta) if err := m.replaceHostByLookup(ctx, accountID, service); err != nil { @@ -696,10 +739,13 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer } m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "") - m.accountManager.UpdateAccountPeers(ctx, accountID) - return service, nil + return &reverseproxy.ExposeServiceResponse{ + ServiceName: service.Name, + ServiceURL: "https://" + service.Domain, + Domain: service.Domain, + }, nil } func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) { @@ -718,6 +764,9 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string } func (m *managerImpl) buildRandomDomain(name string) (string, error) { + if m.clusterDeriver == nil { + return "", fmt.Errorf("unable to get random domain") + } clusterDomains := m.clusterDeriver.GetClusterDomains() if len(clusterDomains) == 0 { return "", fmt.Errorf("no cluster domains found for service %s", name) @@ -727,15 +776,60 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) { return domain, nil } -// DeleteServiceFromPeer deletes a peer-initiated service. -// It validates that the service was created by a peer to prevent deleting API-created services. -func (m *managerImpl) DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { - return m.deletePeerService(ctx, accountID, peerID, serviceID, activity.PeerServiceUnexposed) +// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session. +// Returns an error if the expose is not actively tracked. +func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error { + if !m.exposeTracker.RenewTrackedExpose(peerID, domain) { + return status.Errorf(status.NotFound, "no active expose session for domain %s", domain) + } + return nil } -// ExpireServiceFromPeer deletes a peer-initiated service that was not renewed within the TTL. -func (m *managerImpl) ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error { - return m.deletePeerService(ctx, accountID, peerID, serviceID, activity.PeerServiceExposeExpired) +// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service. +func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error { + if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil { + log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err) + return err + } + + if !m.exposeTracker.StopTrackedExpose(peerID, domain) { + log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain) + } + + return nil +} + +// deleteServiceFromPeer deletes a peer-initiated service identified by domain. +// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed. +func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error { + service, err := m.lookupPeerService(ctx, accountID, peerID, domain) + if err != nil { + return err + } + + activityCode := activity.PeerServiceUnexposed + if expired { + activityCode = activity.PeerServiceExposeExpired + } + return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode) +} + +// lookupPeerService finds a peer-initiated service by domain and validates ownership. +func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) { + service, err := m.store.GetServiceByDomain(ctx, accountID, domain) + if err != nil { + return nil, err + } + + if service.Source != reverseproxy.SourceEphemeral { + return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose") + } + + if service.SourcePeer != peerID { + return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer") + } + + return service, nil } func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error { diff --git a/management/internals/modules/reverseproxy/manager/manager_test.go b/management/internals/modules/reverseproxy/manager/manager_test.go index eab853cf3..17849f622 100644 --- a/management/internals/modules/reverseproxy/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/manager/manager_test.go @@ -658,6 +658,13 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) { PeerExposeEnabled: true, PeerExposeGroups: []string{testGroupID}, }, + Users: map[string]*types.User{ + testUserID: { + Id: testUserID, + AccountID: testAccountID, + Role: types.UserRoleAdmin, + }, + }, Peers: map[string]*nbpeer.Peer{ testPeerID: { ID: testPeerID, @@ -712,16 +719,17 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) { domains: []string{"test.netbird.io"}, }, } + mgr.exposeTracker = &exposeTracker{manager: mgr} return mgr, testStore } -func TestValidateExposePermission(t *testing.T) { +func Test_validateExposePermission(t *testing.T) { ctx := context.Background() t.Run("allowed when peer is in expose group", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - err := mgr.ValidateExposePermission(ctx, testAccountID, testPeerID) + err := mgr.validateExposePermission(ctx, testAccountID, testPeerID) assert.NoError(t, err) }) @@ -742,7 +750,7 @@ func TestValidateExposePermission(t *testing.T) { }) require.NoError(t, err) - err = mgr.ValidateExposePermission(ctx, testAccountID, otherPeerID) + err = mgr.validateExposePermission(ctx, testAccountID, otherPeerID) require.Error(t, err) assert.Contains(t, err.Error(), "not in an allowed expose group") }) @@ -757,7 +765,7 @@ func TestValidateExposePermission(t *testing.T) { err = testStore.SaveAccountSettings(ctx, testAccountID, s) require.NoError(t, err) - err = mgr.ValidateExposePermission(ctx, testAccountID, testPeerID) + err = mgr.validateExposePermission(ctx, testAccountID, testPeerID) require.Error(t, err) assert.Contains(t, err.Error(), "not enabled") }) @@ -772,7 +780,7 @@ func TestValidateExposePermission(t *testing.T) { err = testStore.SaveAccountSettings(ctx, testAccountID, s) require.NoError(t, err) - err = mgr.ValidateExposePermission(ctx, testAccountID, testPeerID) + err = mgr.validateExposePermission(ctx, testAccountID, testPeerID) assert.Error(t, err) }) @@ -781,7 +789,7 @@ func TestValidateExposePermission(t *testing.T) { mockStore := store.NewMockStore(ctrl) mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error")) mgr := &managerImpl{store: mockStore} - err := mgr.ValidateExposePermission(ctx, testAccountID, testPeerID) + err := mgr.validateExposePermission(ctx, testAccountID, testPeerID) require.Error(t, err) assert.Contains(t, err.Error(), "get account settings") }) @@ -793,82 +801,290 @@ func TestCreateServiceFromPeer(t *testing.T) { t.Run("creates service with random domain", func(t *testing.T) { mgr, testStore := setupIntegrationTest(t) - service := &reverseproxy.Service{ - Name: "my-expose", - Enabled: true, - Targets: []*reverseproxy.Target{ - { - AccountID: testAccountID, - Port: 8080, - Protocol: "http", - TargetId: testPeerID, - TargetType: reverseproxy.TargetTypePeer, - Enabled: true, - }, - }, + req := &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", } - created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service) + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - assert.NotEmpty(t, created.ID, "service should have an ID") - assert.Contains(t, created.Domain, "test.netbird.io", "domain should use cluster domain") - assert.Equal(t, reverseproxy.SourceEphemeral, created.Source, "source should be ephemeral") - assert.Equal(t, testPeerID, created.SourcePeer, "source peer should be set") - assert.NotNil(t, created.Meta.LastRenewedAt, "last renewed should be set") + assert.NotEmpty(t, resp.ServiceName, "service name should be generated") + assert.Contains(t, resp.Domain, "test.netbird.io", "domain should use cluster domain") + assert.NotEmpty(t, resp.ServiceURL, "service URL should be set") // Verify service is persisted in store - persisted, err := testStore.GetServiceByID(ctx, store.LockingStrengthNone, testAccountID, created.ID) + persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) require.NoError(t, err) - assert.Equal(t, created.ID, persisted.ID) - assert.Equal(t, created.Domain, persisted.Domain) + assert.Equal(t, resp.Domain, persisted.Domain) + assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral") + assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set") + assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set") }) t.Run("creates service with custom domain", func(t *testing.T) { mgr, _ := setupIntegrationTest(t) - service := &reverseproxy.Service{ - Name: "custom", - Domain: "custom.example.com", - Enabled: true, - Targets: []*reverseproxy.Target{ - { - AccountID: testAccountID, - Port: 80, - Protocol: "http", - TargetId: testPeerID, - TargetType: reverseproxy.TargetTypePeer, - Enabled: true, - }, - }, + req := &reverseproxy.ExposeServiceRequest{ + Port: 80, + Protocol: "http", + Domain: "example.com", } - created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service) + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - assert.Equal(t, "custom.example.com", created.Domain, "should keep the provided domain") + assert.Contains(t, resp.Domain, "example.com", "should use the provided domain") }) - t.Run("replaces host by peer IP lookup", func(t *testing.T) { - mgr, _ := setupIntegrationTest(t) + t.Run("validates expose permission internally", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) - service := &reverseproxy.Service{ - Name: "lookup-test", - Enabled: true, - Targets: []*reverseproxy.Target{ - { - AccountID: testAccountID, - Port: 3000, - Protocol: "http", - TargetId: testPeerID, - TargetType: reverseproxy.TargetTypePeer, - Enabled: true, - }, - }, + // Disable peer expose + s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID) + require.NoError(t, err) + s.PeerExposeEnabled = false + err = testStore.SaveAccountSettings(ctx, testAccountID, s) + require.NoError(t, err) + + req := &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", } - created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service) + _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.Error(t, err) + assert.Contains(t, err.Error(), "not enabled") + }) + + t.Run("validates request fields", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + + req := &reverseproxy.ExposeServiceRequest{ + Port: 0, + Protocol: "http", + } + + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.Error(t, err) + assert.Contains(t, err.Error(), "port") + }) +} + +func TestExposeServiceRequestValidate(t *testing.T) { + tests := []struct { + name string + req reverseproxy.ExposeServiceRequest + wantErr string + }{ + { + name: "valid http request", + req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"}, + wantErr: "", + }, + { + name: "valid https request with pin", + req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"}, + wantErr: "", + }, + { + name: "port zero rejected", + req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"}, + wantErr: "port must be between 1 and 65535", + }, + { + name: "negative port rejected", + req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"}, + wantErr: "port must be between 1 and 65535", + }, + { + name: "port above 65535 rejected", + req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"}, + wantErr: "port must be between 1 and 65535", + }, + { + name: "unsupported protocol", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"}, + wantErr: "unsupported protocol", + }, + { + name: "invalid pin format", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"}, + wantErr: "invalid pin", + }, + { + name: "pin too short", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"}, + wantErr: "invalid pin", + }, + { + name: "valid 6-digit pin", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"}, + wantErr: "", + }, + { + name: "empty user group name", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}}, + wantErr: "user group name cannot be empty", + }, + { + name: "invalid name prefix", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"}, + wantErr: "invalid name prefix", + }, + { + name: "valid name prefix", + req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"}, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.req.Validate() + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } + + t.Run("nil receiver", func(t *testing.T) { + var req *reverseproxy.ExposeServiceRequest + err := req.Validate() + require.Error(t, err) + assert.Contains(t, err.Error(), "request cannot be nil") + }) +} + +func TestDeleteServiceFromPeer_ByDomain(t *testing.T) { + ctx := context.Background() + + t.Run("deletes service by domain", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + // First create a service + req := &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + } + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) require.NoError(t, err) - require.Len(t, created.Targets, 1) - assert.Equal(t, "100.64.0.1", created.Targets[0].Host, "host should be resolved to peer IP") + + // Delete by domain using unexported method + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false) + require.NoError(t, err) + + // Verify service is deleted + _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + require.Error(t, err, "service should be deleted") + }) + + t.Run("expire uses correct activity", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + + req := &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + } + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + + err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true) + require.NoError(t, err) + }) +} + +func TestStopServiceFromPeer(t *testing.T) { + ctx := context.Background() + + t.Run("stops service by domain", func(t *testing.T) { + mgr, testStore := setupIntegrationTest(t) + + req := &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + } + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req) + require.NoError(t, err) + + err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) + + _, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain) + require.Error(t, err, "service should be deleted") + }) +} + +func TestDeleteService_UntracksEphemeralExpose(t *testing.T) { + ctx := context.Background() + mgr, _ := setupIntegrationTest(t) + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + }) + require.NoError(t, err) + assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create") + + // Look up the service by domain to get its store ID + svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain) + require.NoError(t, err) + + // Delete via the API path (user-initiated) + err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID) + require.NoError(t, err) + + assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete") + + // A new expose should succeed (not blocked by stale tracking) + _, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 9090, + Protocol: "http", + }) + assert.NoError(t, err, "new expose should succeed after API delete cleared tracking") +} + +func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) { + ctx := context.Background() + mgr, _ := setupIntegrationTest(t) + + for i := range 3 { + _, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 8080 + i, + Protocol: "http", + }) + require.NoError(t, err) + } + + assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked") + + err := mgr.DeleteAllServices(ctx, testAccountID, testUserID) + require.NoError(t, err) + + assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices") +} + +func TestRenewServiceFromPeer(t *testing.T) { + ctx := context.Background() + + t.Run("renews tracked expose", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + + resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{ + Port: 8080, + Protocol: "http", + }) + require.NoError(t, err) + + err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain) + require.NoError(t, err) + }) + + t.Run("fails for untracked domain", func(t *testing.T) { + mgr, _ := setupIntegrationTest(t) + err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com") + require.Error(t, err) }) } diff --git a/management/internals/modules/reverseproxy/reverseproxy.go b/management/internals/modules/reverseproxy/reverseproxy.go index ebe9ace96..10226710b 100644 --- a/management/internals/modules/reverseproxy/reverseproxy.go +++ b/management/internals/modules/reverseproxy/reverseproxy.go @@ -318,63 +318,6 @@ func isDefaultPort(scheme string, port int) bool { return (scheme == "https" && port == 443) || (scheme == "http" && port == 80) } -// FromExposeRequest builds a Service from a peer expose gRPC request. -func FromExposeRequest(req *proto.ExposeServiceRequest, accountID, peerID, serviceName string) *Service { - service := &Service{ - AccountID: accountID, - Name: serviceName, - Enabled: true, - Targets: []*Target{ - { - AccountID: accountID, - Port: int(req.Port), - Protocol: exposeProtocolToString(req.Protocol), - TargetId: peerID, - TargetType: TargetTypePeer, - Enabled: true, - }, - }, - } - - if req.Domain != "" { - service.Domain = serviceName + "." + req.Domain - } - - if req.Pin != "" { - service.Auth.PinAuth = &PINAuthConfig{ - Enabled: true, - Pin: req.Pin, - } - } - - if req.Password != "" { - service.Auth.PasswordAuth = &PasswordAuthConfig{ - Enabled: true, - Password: req.Password, - } - } - - if len(req.UserGroups) > 0 { - service.Auth.BearerAuth = &BearerAuthConfig{ - Enabled: true, - DistributionGroups: req.UserGroups, - } - } - - return service -} - -func exposeProtocolToString(p proto.ExposeProtocol) string { - switch p { - case proto.ExposeProtocol_EXPOSE_HTTP: - return "http" - case proto.ExposeProtocol_EXPOSE_HTTPS: - return "https" - default: - return "http" - } -} - func (s *Service) FromAPIRequest(req *api.ServiceRequest, accountID string) { s.Name = req.Name s.Domain = req.Domain @@ -534,10 +477,107 @@ func (s *Service) DecryptSensitiveData(enc *crypt.FieldEncrypt) error { return nil } +var pinRegexp = regexp.MustCompile(`^\d{6}$`) + const alphanumCharset = "abcdefghijklmnopqrstuvwxyz0123456789" var validNamePrefix = regexp.MustCompile(`^[a-z0-9]([a-z0-9-]{0,30}[a-z0-9])?$`) +// ExposeServiceRequest contains the parameters for creating a peer-initiated expose service. +type ExposeServiceRequest struct { + NamePrefix string + Port int + Protocol string + Domain string + Pin string + Password string + UserGroups []string +} + +// Validate checks all fields of the expose request. +func (r *ExposeServiceRequest) Validate() error { + if r == nil { + return errors.New("request cannot be nil") + } + + if r.Port < 1 || r.Port > 65535 { + return fmt.Errorf("port must be between 1 and 65535, got %d", r.Port) + } + + if r.Protocol != "http" && r.Protocol != "https" { + return fmt.Errorf("unsupported protocol %q: must be http or https", r.Protocol) + } + + if r.Pin != "" && !pinRegexp.MatchString(r.Pin) { + return errors.New("invalid pin: must be exactly 6 digits") + } + + for _, g := range r.UserGroups { + if g == "" { + return errors.New("user group name cannot be empty") + } + } + + if r.NamePrefix != "" && !validNamePrefix.MatchString(r.NamePrefix) { + return fmt.Errorf("invalid name prefix %q: must be lowercase alphanumeric with optional hyphens, 1-32 characters", r.NamePrefix) + } + + return nil +} + +// ToService builds a Service from the expose request. +func (r *ExposeServiceRequest) ToService(accountID, peerID, serviceName string) *Service { + service := &Service{ + AccountID: accountID, + Name: serviceName, + Enabled: true, + Targets: []*Target{ + { + AccountID: accountID, + Port: r.Port, + Protocol: r.Protocol, + TargetId: peerID, + TargetType: TargetTypePeer, + Enabled: true, + }, + }, + } + + if r.Domain != "" { + service.Domain = serviceName + "." + r.Domain + } + + if r.Pin != "" { + service.Auth.PinAuth = &PINAuthConfig{ + Enabled: true, + Pin: r.Pin, + } + } + + if r.Password != "" { + service.Auth.PasswordAuth = &PasswordAuthConfig{ + Enabled: true, + Password: r.Password, + } + } + + if len(r.UserGroups) > 0 { + service.Auth.BearerAuth = &BearerAuthConfig{ + Enabled: true, + DistributionGroups: r.UserGroups, + } + } + + return service +} + +// ExposeServiceResponse contains the result of a successful peer expose creation. +type ExposeServiceResponse struct { + ServiceName string + ServiceURL string + Domain string +} + // GenerateExposeName generates a random service name for peer-exposed services. // The prefix, if provided, must be a valid DNS label component (lowercase alphanumeric and hyphens). func GenerateExposeName(prefix string) (string, error) { diff --git a/management/internals/modules/reverseproxy/reverseproxy_test.go b/management/internals/modules/reverseproxy/reverseproxy_test.go index c80d7e342..cb75ee61f 100644 --- a/management/internals/modules/reverseproxy/reverseproxy_test.go +++ b/management/internals/modules/reverseproxy/reverseproxy_test.go @@ -458,14 +458,14 @@ func TestGenerateExposeName(t *testing.T) { }) } -func TestFromExposeRequest(t *testing.T) { +func TestExposeServiceRequest_ToService(t *testing.T) { t.Run("basic HTTP service", func(t *testing.T) { - req := &proto.ExposeServiceRequest{ + req := &ExposeServiceRequest{ Port: 8080, - Protocol: proto.ExposeProtocol_EXPOSE_HTTP, + Protocol: "http", } - service := FromExposeRequest(req, "account-1", "peer-1", "mysvc") + service := req.ToService("account-1", "peer-1", "mysvc") assert.Equal(t, "account-1", service.AccountID) assert.Equal(t, "mysvc", service.Name) @@ -483,22 +483,22 @@ func TestFromExposeRequest(t *testing.T) { }) t.Run("with custom domain", func(t *testing.T) { - req := &proto.ExposeServiceRequest{ + req := &ExposeServiceRequest{ Port: 3000, Domain: "example.com", } - service := FromExposeRequest(req, "acc", "peer", "web") + service := req.ToService("acc", "peer", "web") assert.Equal(t, "web.example.com", service.Domain) }) t.Run("with PIN auth", func(t *testing.T) { - req := &proto.ExposeServiceRequest{ + req := &ExposeServiceRequest{ Port: 80, Pin: "1234", } - service := FromExposeRequest(req, "acc", "peer", "svc") + service := req.ToService("acc", "peer", "svc") require.NotNil(t, service.Auth.PinAuth) assert.True(t, service.Auth.PinAuth.Enabled) assert.Equal(t, "1234", service.Auth.PinAuth.Pin) @@ -507,31 +507,31 @@ func TestFromExposeRequest(t *testing.T) { }) t.Run("with password auth", func(t *testing.T) { - req := &proto.ExposeServiceRequest{ + req := &ExposeServiceRequest{ Port: 80, Password: "secret", } - service := FromExposeRequest(req, "acc", "peer", "svc") + service := req.ToService("acc", "peer", "svc") require.NotNil(t, service.Auth.PasswordAuth) assert.True(t, service.Auth.PasswordAuth.Enabled) assert.Equal(t, "secret", service.Auth.PasswordAuth.Password) }) t.Run("with user groups (bearer auth)", func(t *testing.T) { - req := &proto.ExposeServiceRequest{ + req := &ExposeServiceRequest{ Port: 80, UserGroups: []string{"admins", "devs"}, } - service := FromExposeRequest(req, "acc", "peer", "svc") + service := req.ToService("acc", "peer", "svc") require.NotNil(t, service.Auth.BearerAuth) assert.True(t, service.Auth.BearerAuth.Enabled) assert.Equal(t, []string{"admins", "devs"}, service.Auth.BearerAuth.DistributionGroups) }) t.Run("with all auth types", func(t *testing.T) { - req := &proto.ExposeServiceRequest{ + req := &ExposeServiceRequest{ Port: 443, Domain: "myco.com", Pin: "9999", @@ -539,7 +539,7 @@ func TestFromExposeRequest(t *testing.T) { UserGroups: []string{"ops"}, } - service := FromExposeRequest(req, "acc", "peer", "full") + service := req.ToService("acc", "peer", "full") assert.Equal(t, "full.myco.com", service.Domain) require.NotNil(t, service.Auth.PinAuth) require.NotNil(t, service.Auth.PasswordAuth) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 216ea0857..45c1b763f 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -152,8 +152,11 @@ func (s *BaseServer) GRPCServer() *grpc.Server { if err != nil { log.Fatalf("failed to create management server: %v", err) } - srv.SetReverseProxyManager(s.ReverseProxyManager()) - srv.StartExposeReaper(context.Background()) + reverseProxyMgr := s.ReverseProxyManager() + srv.SetReverseProxyManager(reverseProxyMgr) + if reverseProxyMgr != nil { + reverseProxyMgr.StartExposeReaper(context.Background()) + } mgmtProto.RegisterManagementServiceServer(gRPCAPIHandler, srv) mgmtProto.RegisterProxyServiceServer(gRPCAPIHandler, s.ReverseProxyGRPCServer()) diff --git a/management/internals/shared/grpc/expose_service.go b/management/internals/shared/grpc/expose_service.go index 45b60ceec..ef00354af 100644 --- a/management/internals/shared/grpc/expose_service.go +++ b/management/internals/shared/grpc/expose_service.go @@ -2,9 +2,6 @@ package grpc import ( "context" - "regexp" - "sync" - "time" pb "github.com/golang/protobuf/proto" // nolint log "github.com/sirupsen/logrus" @@ -21,27 +18,6 @@ import ( internalStatus "github.com/netbirdio/netbird/shared/management/status" ) -var pinRegexp = regexp.MustCompile(`^\d{6}$`) - -const ( - exposeTTL = 90 * time.Second - exposeReapInterval = 30 * time.Second - maxExposesPerPeer = 10 -) - -type activeExpose struct { - mu sync.Mutex - serviceID string - domain string - accountID string - peerID string - lastRenewed time.Time -} - -func exposeKey(peerID, domain string) string { - return peerID + ":" + domain -} - // CreateExpose handles a peer request to create a new expose service. func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { exposeReq := &proto.ExposeServiceRequest{} @@ -58,72 +34,29 @@ func (s *Server) CreateExpose(ctx context.Context, req *proto.EncryptedMessage) // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) - if exposeReq.Protocol != proto.ExposeProtocol_EXPOSE_HTTP && exposeReq.Protocol != proto.ExposeProtocol_EXPOSE_HTTPS { - return nil, status.Errorf(codes.InvalidArgument, "only HTTP or HTTPS protocol are supported") - } - - if exposeReq.Pin != "" && !pinRegexp.MatchString(exposeReq.Pin) { - return nil, status.Errorf(codes.InvalidArgument, "invalid pin: must be exactly 6 digits") - } - - for _, g := range exposeReq.UserGroups { - if g == "" { - return nil, status.Errorf(codes.InvalidArgument, "user group name cannot be empty") - } - } - reverseProxyMgr := s.getReverseProxyManager() if reverseProxyMgr == nil { return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - if err := reverseProxyMgr.ValidateExposePermission(ctx, accountID, peer.ID); err != nil { - log.WithContext(ctx).Debugf("expose permission denied for peer %s: %v", peer.ID, err) - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } - - serviceName, err := reverseproxy.GenerateExposeName(exposeReq.NamePrefix) + created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, &reverseproxy.ExposeServiceRequest{ + NamePrefix: exposeReq.NamePrefix, + Port: int(exposeReq.Port), + Protocol: exposeProtocolToString(exposeReq.Protocol), + Domain: exposeReq.Domain, + Pin: exposeReq.Pin, + Password: exposeReq.Password, + UserGroups: exposeReq.UserGroups, + }) if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "generate service name: %v", err) + return nil, mapExposeError(ctx, err) } - service := reverseproxy.FromExposeRequest(exposeReq, accountID, peer.ID, serviceName) - - // Serialize the count check to prevent concurrent CreateExpose calls from - // exceeding maxExposesPerPeer. The lock is held only for the check; the - // actual service creation happens outside the lock. - s.exposeCreateMu.Lock() - if s.countPeerExposes(peer.ID) >= maxExposesPerPeer { - s.exposeCreateMu.Unlock() - return nil, status.Errorf(codes.ResourceExhausted, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer) - } - s.exposeCreateMu.Unlock() - - created, err := reverseProxyMgr.CreateServiceFromPeer(ctx, accountID, peer.ID, service) - if err != nil { - log.WithContext(ctx).Errorf("failed to create service from peer: %v", err) - return nil, status.Errorf(codes.Internal, "create service: %v", err) - } - - key := exposeKey(peer.ID, created.Domain) - if _, loaded := s.activeExposes.LoadOrStore(key, &activeExpose{ - serviceID: created.ID, - domain: created.Domain, - accountID: accountID, - peerID: peer.ID, - lastRenewed: time.Now(), - }); loaded { - s.deleteExposeService(ctx, accountID, peer.ID, created) - return nil, status.Errorf(codes.AlreadyExists, "peer already has an active expose session for this domain") - } - - resp := &proto.ExposeServiceResponse{ - ServiceName: created.Name, - ServiceUrl: "https://" + created.Domain, + return s.encryptResponse(peerKey, &proto.ExposeServiceResponse{ + ServiceName: created.ServiceName, + ServiceUrl: created.ServiceURL, Domain: created.Domain, - } - - return s.encryptResponse(peerKey, resp) + }) } // RenewExpose extends the TTL of an active expose session. @@ -134,21 +67,19 @@ func (s *Server) RenewExpose(ctx context.Context, req *proto.EncryptedMessage) ( return nil, err } - _, peer, err := s.authenticateExposePeer(ctx, peerKey) + accountID, peer, err := s.authenticateExposePeer(ctx, peerKey) if err != nil { return nil, err } - key := exposeKey(peer.ID, renewReq.Domain) - val, ok := s.activeExposes.Load(key) - if !ok { - return nil, status.Errorf(codes.NotFound, "no active expose session for domain %s", renewReq.Domain) + reverseProxyMgr := s.getReverseProxyManager() + if reverseProxyMgr == nil { + return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - expose := val.(*activeExpose) - expose.mu.Lock() - expose.lastRenewed = time.Now() - expose.mu.Unlock() + if err := reverseProxyMgr.RenewServiceFromPeer(ctx, accountID, peer.ID, renewReq.Domain); err != nil { + return nil, mapExposeError(ctx, err) + } return s.encryptResponse(peerKey, &proto.RenewExposeResponse{}) } @@ -161,55 +92,45 @@ func (s *Server) StopExpose(ctx context.Context, req *proto.EncryptedMessage) (* return nil, err } - _, peer, err := s.authenticateExposePeer(ctx, peerKey) + accountID, peer, err := s.authenticateExposePeer(ctx, peerKey) if err != nil { return nil, err } - key := exposeKey(peer.ID, stopReq.Domain) - val, ok := s.activeExposes.LoadAndDelete(key) - if !ok { - return nil, status.Errorf(codes.NotFound, "no active expose session for domain %s", stopReq.Domain) + reverseProxyMgr := s.getReverseProxyManager() + if reverseProxyMgr == nil { + return nil, status.Errorf(codes.Internal, "reverse proxy manager not available") } - expose := val.(*activeExpose) - s.cleanupExpose(expose, false) + if err := reverseProxyMgr.StopServiceFromPeer(ctx, accountID, peer.ID, stopReq.Domain); err != nil { + return nil, mapExposeError(ctx, err) + } return s.encryptResponse(peerKey, &proto.StopExposeResponse{}) } -// StartExposeReaper starts a background goroutine that reaps expired expose sessions. -func (s *Server) StartExposeReaper(ctx context.Context) { - go func() { - ticker := time.NewTicker(exposeReapInterval) - defer ticker.Stop() +func mapExposeError(ctx context.Context, err error) error { + s, ok := internalStatus.FromError(err) + if !ok { + log.WithContext(ctx).Errorf("expose service error: %v", err) + return status.Errorf(codes.Internal, "internal error") + } - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - s.reapExpiredExposes() - } - } - }() -} - -func (s *Server) reapExpiredExposes() { - s.activeExposes.Range(func(key, val any) bool { - expose := val.(*activeExpose) - expose.mu.Lock() - expired := time.Since(expose.lastRenewed) > exposeTTL - expose.mu.Unlock() - - if expired { - if _, deleted := s.activeExposes.LoadAndDelete(key); deleted { - log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain) - s.cleanupExpose(expose, true) - } - } - return true - }) + switch s.Type() { + case internalStatus.InvalidArgument: + return status.Errorf(codes.InvalidArgument, "%s", s.Message) + case internalStatus.PermissionDenied: + return status.Errorf(codes.PermissionDenied, "%s", s.Message) + case internalStatus.NotFound: + return status.Errorf(codes.NotFound, "%s", s.Message) + case internalStatus.AlreadyExists: + return status.Errorf(codes.AlreadyExists, "%s", s.Message) + case internalStatus.PreconditionFailed: + return status.Errorf(codes.ResourceExhausted, "%s", s.Message) + default: + log.WithContext(ctx).Errorf("expose service error: %v", err) + return status.Errorf(codes.Internal, "internal error") + } } func (s *Server) encryptResponse(peerKey wgtypes.Key, msg pb.Message) (*proto.EncryptedMessage, error) { @@ -246,47 +167,6 @@ func (s *Server) authenticateExposePeer(ctx context.Context, peerKey wgtypes.Key return accountID, peer, nil } -func (s *Server) deleteExposeService(ctx context.Context, accountID, peerID string, service *reverseproxy.Service) { - reverseProxyMgr := s.getReverseProxyManager() - if reverseProxyMgr == nil { - return - } - if err := reverseProxyMgr.DeleteServiceFromPeer(ctx, accountID, peerID, service.ID); err != nil { - log.WithContext(ctx).Debugf("failed to delete expose service %s: %v", service.ID, err) - } -} - -func (s *Server) cleanupExpose(expose *activeExpose, expired bool) { - bgCtx := context.Background() - - reverseProxyMgr := s.getReverseProxyManager() - if reverseProxyMgr == nil { - log.Errorf("cannot cleanup exposed service %s: reverse proxy manager not available", expose.serviceID) - return - } - - var err error - if expired { - err = reverseProxyMgr.ExpireServiceFromPeer(bgCtx, expose.accountID, expose.peerID, expose.serviceID) - } else { - err = reverseProxyMgr.DeleteServiceFromPeer(bgCtx, expose.accountID, expose.peerID, expose.serviceID) - } - if err != nil { - log.Errorf("failed to delete peer-exposed service %s: %v", expose.serviceID, err) - } -} - -func (s *Server) countPeerExposes(peerID string) int { - count := 0 - s.activeExposes.Range(func(_, val any) bool { - if expose := val.(*activeExpose); expose.peerID == peerID { - count++ - } - return true - }) - return count -} - func (s *Server) getReverseProxyManager() reverseproxy.Manager { s.reverseProxyMu.RLock() defer s.reverseProxyMu.RUnlock() @@ -299,3 +179,14 @@ func (s *Server) SetReverseProxyManager(mgr reverseproxy.Manager) { defer s.reverseProxyMu.Unlock() s.reverseProxyManager = mgr } + +func exposeProtocolToString(p proto.ExposeProtocol) string { + switch p { + case proto.ExposeProtocol_EXPOSE_HTTP: + return "http" + case proto.ExposeProtocol_EXPOSE_HTTPS: + return "https" + default: + return "http" + } +} diff --git a/management/internals/shared/grpc/expose_service_test.go b/management/internals/shared/grpc/expose_service_test.go deleted file mode 100644 index 75a16ae44..000000000 --- a/management/internals/shared/grpc/expose_service_test.go +++ /dev/null @@ -1,242 +0,0 @@ -package grpc - -import ( - "sync" - "testing" - "time" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/management/internals/modules/reverseproxy" -) - -func TestPinValidation(t *testing.T) { - tests := []struct { - pin string - valid bool - }{ - {"123456", true}, - {"000000", true}, - {"12345", false}, - {"1234567", false}, - {"abcdef", false}, - {"12345a", false}, - {"", false}, - {"12 345", false}, - } - - for _, tt := range tests { - assert.Equal(t, tt.valid, pinRegexp.MatchString(tt.pin), "pin %q", tt.pin) - } -} - -func TestExposeKey(t *testing.T) { - assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com")) - assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com")) - assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com")) -} - -func TestCountPeerExposes(t *testing.T) { - s := &Server{} - - // No exposes - assert.Equal(t, 0, s.countPeerExposes("peer1")) - - // Add some exposes for different peers - s.activeExposes.Store("peer1:a.com", &activeExpose{peerID: "peer1"}) - s.activeExposes.Store("peer1:b.com", &activeExpose{peerID: "peer1"}) - s.activeExposes.Store("peer2:a.com", &activeExpose{peerID: "peer2"}) - - assert.Equal(t, 2, s.countPeerExposes("peer1"), "peer1 should have 2 exposes") - assert.Equal(t, 1, s.countPeerExposes("peer2"), "peer2 should have 1 expose") - assert.Equal(t, 0, s.countPeerExposes("peer3"), "peer3 should have 0 exposes") -} - -func TestReapExpiredExposes(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := reverseproxy.NewMockManager(ctrl) - - s := &Server{} - s.SetReverseProxyManager(mockMgr) - - now := time.Now() - - // Add an expired expose and a still-active one - s.activeExposes.Store("peer1:expired.com", &activeExpose{ - serviceID: "svc-expired", - domain: "expired.com", - accountID: "acct1", - peerID: "peer1", - lastRenewed: now.Add(-2 * exposeTTL), - }) - s.activeExposes.Store("peer1:active.com", &activeExpose{ - serviceID: "svc-active", - domain: "active.com", - accountID: "acct1", - peerID: "peer1", - lastRenewed: now, - }) - - // Expect ExpireServiceFromPeer called only for the expired one - mockMgr.EXPECT(). - ExpireServiceFromPeer(gomock.Any(), "acct1", "peer1", "svc-expired"). - Return(nil) - - s.reapExpiredExposes() - - // Verify expired one is removed - _, exists := s.activeExposes.Load("peer1:expired.com") - assert.False(t, exists, "expired expose should be removed") - - // Verify active one remains - _, exists = s.activeExposes.Load("peer1:active.com") - assert.True(t, exists, "active expose should remain") -} - -func TestCleanupExpose_Delete(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := reverseproxy.NewMockManager(ctrl) - - s := &Server{} - s.SetReverseProxyManager(mockMgr) - - mockMgr.EXPECT(). - DeleteServiceFromPeer(gomock.Any(), "acct1", "peer1", "svc1"). - Return(nil) - - s.cleanupExpose(&activeExpose{ - serviceID: "svc1", - accountID: "acct1", - peerID: "peer1", - }, false) -} - -func TestCleanupExpose_Expire(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := reverseproxy.NewMockManager(ctrl) - - s := &Server{} - s.SetReverseProxyManager(mockMgr) - - mockMgr.EXPECT(). - ExpireServiceFromPeer(gomock.Any(), "acct1", "peer1", "svc1"). - Return(nil) - - s.cleanupExpose(&activeExpose{ - serviceID: "svc1", - accountID: "acct1", - peerID: "peer1", - }, true) -} - -func TestCleanupExpose_NilManager(t *testing.T) { - s := &Server{} - // Should not panic when reverse proxy manager is nil - s.cleanupExpose(&activeExpose{ - serviceID: "svc1", - accountID: "acct1", - peerID: "peer1", - }, false) -} - -func TestSetReverseProxyManager(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - s := &Server{} - - // Initially nil - assert.Nil(t, s.getReverseProxyManager()) - - mockMgr := reverseproxy.NewMockManager(ctrl) - s.SetReverseProxyManager(mockMgr) - assert.NotNil(t, s.getReverseProxyManager()) - - // Can set to nil - s.SetReverseProxyManager(nil) - assert.Nil(t, s.getReverseProxyManager()) -} - -func TestReapExpiredExposes_ConcurrentSafety(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - mockMgr := reverseproxy.NewMockManager(ctrl) - mockMgr.EXPECT(). - ExpireServiceFromPeer(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(nil). - AnyTimes() - - s := &Server{} - s.SetReverseProxyManager(mockMgr) - - // Pre-populate with expired sessions - for i := range 20 { - peerID := "peer1" - domain := "domain-" + string(rune('a'+i)) - s.activeExposes.Store(exposeKey(peerID, domain), &activeExpose{ - serviceID: "svc-" + domain, - domain: domain, - accountID: "acct1", - peerID: peerID, - lastRenewed: time.Now().Add(-2 * exposeTTL), - }) - } - - // Run reaper concurrently with count - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - s.reapExpiredExposes() - }() - go func() { - defer wg.Done() - s.countPeerExposes("peer1") - }() - wg.Wait() - - assert.Equal(t, 0, s.countPeerExposes("peer1"), "all expired exposes should be reaped") -} - -func TestActiveExposeMutexProtectsLastRenewed(t *testing.T) { - expose := &activeExpose{ - lastRenewed: time.Now().Add(-1 * time.Hour), - } - - // Simulate concurrent renew and read - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - for range 100 { - expose.mu.Lock() - expose.lastRenewed = time.Now() - expose.mu.Unlock() - } - }() - - go func() { - defer wg.Done() - for range 100 { - expose.mu.Lock() - _ = time.Since(expose.lastRenewed) - expose.mu.Unlock() - } - }() - - wg.Wait() - - expose.mu.Lock() - require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access") - expose.mu.Unlock() -} diff --git a/management/internals/shared/grpc/proxy_group_access_test.go b/management/internals/shared/grpc/proxy_group_access_test.go index 611ee36b6..827897981 100644 --- a/management/internals/shared/grpc/proxy_group_access_test.go +++ b/management/internals/shared/grpc/proxy_group_access_test.go @@ -76,21 +76,19 @@ func (m *mockReverseProxyManager) GetServiceIDByTargetID(_ context.Context, _, _ return "", nil } -func (m *mockReverseProxyManager) ValidateExposePermission(_ context.Context, _, _ string) error { +func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { + return &reverseproxy.ExposeServiceResponse{}, nil +} + +func (m *mockReverseProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *mockReverseProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil -} - -func (m *mockReverseProxyManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *mockReverseProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *mockReverseProxyManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error { - return nil -} +func (m *mockReverseProxyManager) StartExposeReaper(_ context.Context) {} type mockUsersManager struct { users map[string]*types.User diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 3df9ce7ba..029d71e2e 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -82,8 +82,6 @@ type Server struct { syncLimEnabled bool syncLim int32 - activeExposes sync.Map - exposeCreateMu sync.Mutex reverseProxyManager reverseproxy.Manager reverseProxyMu sync.RWMutex } diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 1e03a461a..640a27bb2 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -196,7 +196,7 @@ func TestValidateSession_ProxyNotFound(t *testing.T) { require.NoError(t, err) assert.False(t, resp.Valid, "Unknown proxy should be denied") - assert.Equal(t, "proxy_not_found", resp.DeniedReason) + assert.Equal(t, "service_not_found", resp.DeniedReason) } func TestValidateSession_InvalidToken(t *testing.T) { @@ -263,6 +263,10 @@ func (m *testValidateSessionProxyManager) DeleteService(_ context.Context, _, _, return nil } +func (m *testValidateSessionProxyManager) DeleteAllServices(_ context.Context, _, _ string) error { + return nil +} + func (m *testValidateSessionProxyManager) SetCertificateIssuedAt(_ context.Context, _, _ string) error { return nil } @@ -295,22 +299,20 @@ func (m *testValidateSessionProxyManager) GetServiceIDByTargetID(_ context.Conte return "", nil } -func (m *testValidateSessionProxyManager) ValidateExposePermission(_ context.Context, _, _ string) error { - return nil -} - -func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testValidateSessionProxyManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { return nil, nil } -func (m *testValidateSessionProxyManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *testValidateSessionProxyManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } +func (m *testValidateSessionProxyManager) StartExposeReaper(_ context.Context) {} + type testValidateSessionUsersManager struct { store store.Store } diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 77d50d818..12634dda4 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -413,22 +413,20 @@ func (m *testServiceManager) GetServiceIDByTargetID(_ context.Context, _, _ stri return "", nil } -func (m *testServiceManager) ValidateExposePermission(_ context.Context, _, _ string) error { - return nil -} - -func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { +func (m *testServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { return nil, nil } -func (m *testServiceManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *testServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *testServiceManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *testServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } +func (m *testServiceManager) StartExposeReaper(_ context.Context) {} + func createTestState(t *testing.T, ps *nbgrpc.ProxyServiceServer, redirectURL string) string { t.Helper() diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 12cec89ff..e91335a81 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -247,21 +247,19 @@ func (m *storeBackedServiceManager) GetServiceIDByTargetID(ctx context.Context, return "", nil } -func (m *storeBackedServiceManager) ValidateExposePermission(_ context.Context, _, _ string) error { +func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) { + return &reverseproxy.ExposeServiceResponse{}, nil +} + +func (m *storeBackedServiceManager) RenewServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *storeBackedServiceManager) CreateServiceFromPeer(_ context.Context, _, _ string, _ *reverseproxy.Service) (*reverseproxy.Service, error) { - return &reverseproxy.Service{}, nil -} - -func (m *storeBackedServiceManager) DeleteServiceFromPeer(_ context.Context, _, _, _ string) error { +func (m *storeBackedServiceManager) StopServiceFromPeer(_ context.Context, _, _, _ string) error { return nil } -func (m *storeBackedServiceManager) ExpireServiceFromPeer(_ context.Context, _, _, _ string) error { - return nil -} +func (m *storeBackedServiceManager) StartExposeReaper(_ context.Context) {} func strPtr(s string) *string { return &s