[management] Refactor expose feature: move business logic from gRPC to manager (#5435)

Consolidate all expose business logic (validation, permission checks, TTL tracking, reaping) into the manager layer, making the gRPC layer a pure transport adapter that only handles proto conversion and authentication.

- Add ExposeServiceRequest/ExposeServiceResponse domain types with validation in the reverseproxy package
- Move expose tracker (TTL tracking, reaping, per-peer limits) from gRPC server into manager/expose_tracker.go
- Internalize tracking in CreateServiceFromPeer, RenewServiceFromPeer, and new StopServiceFromPeer so callers don't manage tracker state
- Untrack ephemeral services in DeleteService/DeleteAllServices to keep tracker in sync when services are deleted via API
- Simplify gRPC expose handlers to parse, auth, convert, delegate
- Remove tracker methods from Manager interface (internal detail)
This commit is contained in:
Maycon Santos
2026-02-24 15:09:30 +01:00
committed by GitHub
parent f8c0321aee
commit 327142837c
17 changed files with 1072 additions and 659 deletions

View File

@@ -0,0 +1,163 @@
package manager
import (
"context"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/status"
log "github.com/sirupsen/logrus"
)
const (
exposeTTL = 90 * time.Second
exposeReapInterval = 30 * time.Second
maxExposesPerPeer = 10
)
type trackedExpose struct {
mu sync.Mutex
domain string
accountID string
peerID string
lastRenewed time.Time
expiring bool
}
type exposeTracker struct {
activeExposes sync.Map
exposeCreateMu sync.Mutex
manager *managerImpl
}
func exposeKey(peerID, domain string) string {
return peerID + ":" + domain
}
// TrackExposeIfAllowed atomically checks the per-peer limit and registers a new
// active expose session under the same lock. Returns (true, false) if the expose
// was already tracked (duplicate), (false, true) if tracking succeeded, and
// (false, false) if the peer has reached the limit.
func (t *exposeTracker) TrackExposeIfAllowed(peerID, domain, accountID string) (alreadyTracked, ok bool) {
t.exposeCreateMu.Lock()
defer t.exposeCreateMu.Unlock()
key := exposeKey(peerID, domain)
_, loaded := t.activeExposes.LoadOrStore(key, &trackedExpose{
domain: domain,
accountID: accountID,
peerID: peerID,
lastRenewed: time.Now(),
})
if loaded {
return true, false
}
if t.CountPeerExposes(peerID) > maxExposesPerPeer {
t.activeExposes.Delete(key)
return false, false
}
return false, true
}
// UntrackExpose removes an active expose session from tracking.
func (t *exposeTracker) UntrackExpose(peerID, domain string) {
t.activeExposes.Delete(exposeKey(peerID, domain))
}
// CountPeerExposes returns the number of active expose sessions for a peer.
func (t *exposeTracker) CountPeerExposes(peerID string) int {
count := 0
t.activeExposes.Range(func(_, val any) bool {
if expose := val.(*trackedExpose); expose.peerID == peerID {
count++
}
return true
})
return count
}
// MaxExposesPerPeer returns the maximum number of concurrent exposes allowed per peer.
func (t *exposeTracker) MaxExposesPerPeer() int {
return maxExposesPerPeer
}
// RenewTrackedExpose updates the in-memory lastRenewed timestamp for a tracked expose.
// Returns false if the expose is not tracked or is being reaped.
func (t *exposeTracker) RenewTrackedExpose(peerID, domain string) bool {
key := exposeKey(peerID, domain)
val, ok := t.activeExposes.Load(key)
if !ok {
return false
}
expose := val.(*trackedExpose)
expose.mu.Lock()
if expose.expiring {
expose.mu.Unlock()
return false
}
expose.lastRenewed = time.Now()
expose.mu.Unlock()
return true
}
// StopTrackedExpose removes an active expose session from tracking.
// Returns false if the expose was not tracked.
func (t *exposeTracker) StopTrackedExpose(peerID, domain string) bool {
key := exposeKey(peerID, domain)
_, ok := t.activeExposes.LoadAndDelete(key)
return ok
}
// StartExposeReaper starts a background goroutine that reaps expired expose sessions.
func (t *exposeTracker) StartExposeReaper(ctx context.Context) {
go func() {
ticker := time.NewTicker(exposeReapInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
t.reapExpiredExposes()
}
}
}()
}
func (t *exposeTracker) reapExpiredExposes() {
t.activeExposes.Range(func(key, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expired := time.Since(expose.lastRenewed) > exposeTTL
if expired {
expose.expiring = true
}
expose.mu.Unlock()
if !expired {
return true
}
log.Infof("reaping expired expose session for peer %s, domain %s", expose.peerID, expose.domain)
err := t.manager.deleteServiceFromPeer(context.Background(), expose.accountID, expose.peerID, expose.domain, true)
s, _ := status.FromError(err)
switch {
case err == nil:
t.activeExposes.Delete(key)
case s.ErrorType == status.NotFound:
log.Debugf("service %s was already deleted", expose.domain)
default:
log.Errorf("failed to delete expired peer-exposed service for domain %s: %v", expose.domain, err)
}
return true
})
}

View File

@@ -0,0 +1,256 @@
package manager
import (
"context"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy"
)
func TestExposeKey(t *testing.T) {
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
}
func TestTrackExposeIfAllowed(t *testing.T) {
t.Run("first track succeeds", func(t *testing.T) {
tracker := &exposeTracker{}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.False(t, alreadyTracked, "first track should not be duplicate")
assert.True(t, ok, "first track should be allowed")
})
t.Run("duplicate track detected", func(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.True(t, alreadyTracked, "second track should be duplicate")
assert.False(t, ok)
})
t.Run("rejects when at limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
assert.True(t, ok, "track %d should be allowed", i)
}
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
assert.False(t, alreadyTracked)
assert.False(t, ok, "should reject when at limit")
})
t.Run("other peer unaffected by limit", func(t *testing.T) {
tracker := &exposeTracker{}
for i := range maxExposesPerPeer {
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
}
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.True(t, ok, "other peer should still be within limit")
})
}
func TestUntrackExpose(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
tracker.UntrackExpose("peer1", "a.com")
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
}
func TestCountPeerExposes(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
}
func TestMaxExposesPerPeer(t *testing.T) {
tracker := &exposeTracker{}
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
}
func TestRenewTrackedExpose(t *testing.T) {
tracker := &exposeTracker{}
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should not find untracked expose")
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
found = tracker.RenewTrackedExpose("peer1", "a.com")
assert.True(t, found, "should find tracked expose")
}
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
tracker := &exposeTracker{}
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
// Simulate reaper marking the expose as expiring
key := exposeKey("peer1", "a.com")
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.expiring = true
expose.mu.Unlock()
found := tracker.RenewTrackedExpose("peer1", "a.com")
assert.False(t, found, "should reject renewal when expiring")
}
func TestReapExpiredExposes(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
// Manually expire the tracked entry
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Add an active (non-expired) tracking entry
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
domain: "active.com",
accountID: testAccountID,
peerID: "peer1",
lastRenewed: time.Now(),
})
tracker.reapExpiredExposes()
_, exists := tracker.activeExposes.Load(key)
assert.False(t, exists, "expired expose should be removed")
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
assert.True(t, exists, "active expose should remain")
}
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
key := exposeKey(testPeerID, resp.Domain)
val, _ := tracker.activeExposes.Load(key)
expose := val.(*trackedExpose)
// Expire it
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
// Renew should succeed before reaping
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
// Re-expire and reap
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
tracker.reapExpiredExposes()
// Entry is deleted, renew returns false
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
}
func TestConcurrentTrackAndCount(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
tracker := mgr.exposeTracker
ctx := context.Background()
for i := range 5 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
// Manually expire all tracked entries
tracker.activeExposes.Range(func(_, val any) bool {
expose := val.(*trackedExpose)
expose.mu.Lock()
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
expose.mu.Unlock()
return true
})
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
tracker.reapExpiredExposes()
}()
go func() {
defer wg.Done()
tracker.CountPeerExposes(testPeerID)
}()
wg.Wait()
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
}
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
expose := &trackedExpose{
lastRenewed: time.Now().Add(-1 * time.Hour),
}
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
expose.lastRenewed = time.Now()
expose.mu.Unlock()
}
}()
go func() {
defer wg.Done()
for range 100 {
expose.mu.Lock()
_ = time.Since(expose.lastRenewed)
expose.mu.Unlock()
}
}()
wg.Wait()
expose.mu.Lock()
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
expose.mu.Unlock()
}

View File

@@ -40,11 +40,12 @@ type managerImpl struct {
settingsManager settings.Manager
proxyGRPCServer *nbgrpc.ProxyServiceServer
clusterDeriver ClusterDeriver
exposeTracker *exposeTracker
}
// NewManager creates a new service manager.
func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, settingsManager settings.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, clusterDeriver ClusterDeriver) reverseproxy.Manager {
return &managerImpl{
mgr := &managerImpl{
store: store,
accountManager: accountManager,
permissionsManager: permissionsManager,
@@ -52,6 +53,13 @@ func NewManager(store store.Store, accountManager account.Manager, permissionsMa
proxyGRPCServer: proxyGRPCServer,
clusterDeriver: clusterDeriver,
}
mgr.exposeTracker = &exposeTracker{manager: mgr}
return mgr
}
// StartExposeReaper delegates to the expose tracker.
func (m *managerImpl) StartExposeReaper(ctx context.Context) {
m.exposeTracker.StartExposeReaper(ctx)
}
func (m *managerImpl) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
@@ -418,6 +426,10 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
return err
}
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
}
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
@@ -460,6 +472,9 @@ func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID s
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
for _, service := range services {
if service.Source == reverseproxy.SourceEphemeral {
m.exposeTracker.UntrackExpose(service.SourcePeer, service.Domain)
}
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
@@ -617,9 +632,9 @@ func (m *managerImpl) GetServiceIDByTargetID(ctx context.Context, accountID stri
return target.ServiceID, nil
}
// ValidateExposePermission checks whether the peer is allowed to use the expose feature.
// validateExposePermission checks whether the peer is allowed to use the expose feature.
// It verifies the account has peer expose enabled and that the peer belongs to an allowed group.
func (m *managerImpl) ValidateExposePermission(ctx context.Context, accountID, peerID string) error {
func (m *managerImpl) validateExposePermission(ctx context.Context, accountID, peerID string) error {
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
@@ -650,8 +665,23 @@ func (m *managerImpl) ValidateExposePermission(ctx context.Context, accountID, p
}
// CreateServiceFromPeer creates a service initiated by a peer expose request.
// It skips user permission checks since authorization is done at the gRPC handler level.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, service *reverseproxy.Service) (*reverseproxy.Service, error) {
// It validates the request, checks expose permissions, enforces the per-peer limit,
// creates the service, and tracks it for TTL-based reaping.
func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peerID string, req *reverseproxy.ExposeServiceRequest) (*reverseproxy.ExposeServiceResponse, error) {
if err := req.Validate(); err != nil {
return nil, status.Errorf(status.InvalidArgument, "validate expose request: %v", err)
}
if err := m.validateExposePermission(ctx, accountID, peerID); err != nil {
return nil, err
}
serviceName, err := reverseproxy.GenerateExposeName(req.NamePrefix)
if err != nil {
return nil, status.Errorf(status.InvalidArgument, "generate service name: %v", err)
}
service := req.ToService(accountID, peerID, serviceName)
service.Source = reverseproxy.SourceEphemeral
if service.Domain == "" {
@@ -665,7 +695,7 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
if service.Auth.BearerAuth != nil && service.Auth.BearerAuth.Enabled {
groupIDs, err := m.getGroupIDsFromNames(ctx, accountID, service.Auth.BearerAuth.DistributionGroups)
if err != nil {
return nil, fmt.Errorf("get group ids for service %s: %w", service.ID, err)
return nil, fmt.Errorf("get group ids for service %s: %w", service.Name, err)
}
service.Auth.BearerAuth.DistributionGroups = groupIDs
}
@@ -687,8 +717,21 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
return nil, err
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
alreadyTracked, allowed := m.exposeTracker.TrackExposeIfAllowed(peerID, service.Domain, accountID)
if alreadyTracked {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete duplicate expose service for domain %s: %v", service.Domain, err)
}
return nil, status.Errorf(status.AlreadyExists, "peer already has an active expose session for this domain")
}
if !allowed {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, service.Domain, false); err != nil {
log.WithContext(ctx).Debugf("failed to delete service after limit exceeded for domain %s: %v", service.Domain, err)
}
return nil, status.Errorf(status.PreconditionFailed, "peer has reached the maximum number of active expose sessions (%d)", maxExposesPerPeer)
}
meta := addPeerInfoToEventMeta(service.EventMeta(), peer)
m.accountManager.StoreEvent(ctx, peerID, service.ID, accountID, activity.PeerServiceExposed, meta)
if err := m.replaceHostByLookup(ctx, accountID, service); err != nil {
@@ -696,10 +739,13 @@ func (m *managerImpl) CreateServiceFromPeer(ctx context.Context, accountID, peer
}
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
m.accountManager.UpdateAccountPeers(ctx, accountID)
return service, nil
return &reverseproxy.ExposeServiceResponse{
ServiceName: service.Name,
ServiceURL: "https://" + service.Domain,
Domain: service.Domain,
}, nil
}
func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string, groupNames []string) ([]string, error) {
@@ -718,6 +764,9 @@ func (m *managerImpl) getGroupIDsFromNames(ctx context.Context, accountID string
}
func (m *managerImpl) buildRandomDomain(name string) (string, error) {
if m.clusterDeriver == nil {
return "", fmt.Errorf("unable to get random domain")
}
clusterDomains := m.clusterDeriver.GetClusterDomains()
if len(clusterDomains) == 0 {
return "", fmt.Errorf("no cluster domains found for service %s", name)
@@ -727,15 +776,60 @@ func (m *managerImpl) buildRandomDomain(name string) (string, error) {
return domain, nil
}
// DeleteServiceFromPeer deletes a peer-initiated service.
// It validates that the service was created by a peer to prevent deleting API-created services.
func (m *managerImpl) DeleteServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.deletePeerService(ctx, accountID, peerID, serviceID, activity.PeerServiceUnexposed)
// RenewServiceFromPeer renews the in-memory TTL tracker for the peer's expose session.
// Returns an error if the expose is not actively tracked.
func (m *managerImpl) RenewServiceFromPeer(_ context.Context, _, peerID, domain string) error {
if !m.exposeTracker.RenewTrackedExpose(peerID, domain) {
return status.Errorf(status.NotFound, "no active expose session for domain %s", domain)
}
return nil
}
// ExpireServiceFromPeer deletes a peer-initiated service that was not renewed within the TTL.
func (m *managerImpl) ExpireServiceFromPeer(ctx context.Context, accountID, peerID, serviceID string) error {
return m.deletePeerService(ctx, accountID, peerID, serviceID, activity.PeerServiceExposeExpired)
// StopServiceFromPeer stops a peer's active expose session by untracking and deleting the service.
func (m *managerImpl) StopServiceFromPeer(ctx context.Context, accountID, peerID, domain string) error {
if err := m.deleteServiceFromPeer(ctx, accountID, peerID, domain, false); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer-exposed service for domain %s: %v", domain, err)
return err
}
if !m.exposeTracker.StopTrackedExpose(peerID, domain) {
log.WithContext(ctx).Warnf("expose tracker entry for domain %s already removed; service was deleted", domain)
}
return nil
}
// deleteServiceFromPeer deletes a peer-initiated service identified by domain.
// When expired is true, the activity is recorded as PeerServiceExposeExpired instead of PeerServiceUnexposed.
func (m *managerImpl) deleteServiceFromPeer(ctx context.Context, accountID, peerID, domain string, expired bool) error {
service, err := m.lookupPeerService(ctx, accountID, peerID, domain)
if err != nil {
return err
}
activityCode := activity.PeerServiceUnexposed
if expired {
activityCode = activity.PeerServiceExposeExpired
}
return m.deletePeerService(ctx, accountID, peerID, service.ID, activityCode)
}
// lookupPeerService finds a peer-initiated service by domain and validates ownership.
func (m *managerImpl) lookupPeerService(ctx context.Context, accountID, peerID, domain string) (*reverseproxy.Service, error) {
service, err := m.store.GetServiceByDomain(ctx, accountID, domain)
if err != nil {
return nil, err
}
if service.Source != reverseproxy.SourceEphemeral {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on API-created service via peer expose")
}
if service.SourcePeer != peerID {
return nil, status.Errorf(status.PermissionDenied, "cannot operate on service exposed by another peer")
}
return service, nil
}
func (m *managerImpl) deletePeerService(ctx context.Context, accountID, peerID, serviceID string, activityCode activity.Activity) error {

View File

@@ -658,6 +658,13 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
PeerExposeEnabled: true,
PeerExposeGroups: []string{testGroupID},
},
Users: map[string]*types.User{
testUserID: {
Id: testUserID,
AccountID: testAccountID,
Role: types.UserRoleAdmin,
},
},
Peers: map[string]*nbpeer.Peer{
testPeerID: {
ID: testPeerID,
@@ -712,16 +719,17 @@ func setupIntegrationTest(t *testing.T) (*managerImpl, store.Store) {
domains: []string{"test.netbird.io"},
},
}
mgr.exposeTracker = &exposeTracker{manager: mgr}
return mgr, testStore
}
func TestValidateExposePermission(t *testing.T) {
func Test_validateExposePermission(t *testing.T) {
ctx := context.Background()
t.Run("allowed when peer is in expose group", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
assert.NoError(t, err)
})
@@ -742,7 +750,7 @@ func TestValidateExposePermission(t *testing.T) {
})
require.NoError(t, err)
err = mgr.ValidateExposePermission(ctx, testAccountID, otherPeerID)
err = mgr.validateExposePermission(ctx, testAccountID, otherPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not in an allowed expose group")
})
@@ -757,7 +765,7 @@ func TestValidateExposePermission(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
err = mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err = mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "not enabled")
})
@@ -772,7 +780,7 @@ func TestValidateExposePermission(t *testing.T) {
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
err = mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err = mgr.validateExposePermission(ctx, testAccountID, testPeerID)
assert.Error(t, err)
})
@@ -781,7 +789,7 @@ func TestValidateExposePermission(t *testing.T) {
mockStore := store.NewMockStore(ctrl)
mockStore.EXPECT().GetAccountSettings(gomock.Any(), gomock.Any(), testAccountID).Return(nil, errors.New("store error"))
mgr := &managerImpl{store: mockStore}
err := mgr.ValidateExposePermission(ctx, testAccountID, testPeerID)
err := mgr.validateExposePermission(ctx, testAccountID, testPeerID)
require.Error(t, err)
assert.Contains(t, err.Error(), "get account settings")
})
@@ -793,82 +801,290 @@ func TestCreateServiceFromPeer(t *testing.T) {
t.Run("creates service with random domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
service := &reverseproxy.Service{
Name: "my-expose",
Enabled: true,
Targets: []*reverseproxy.Target{
{
AccountID: testAccountID,
Port: 8080,
Protocol: "http",
TargetId: testPeerID,
TargetType: reverseproxy.TargetTypePeer,
Enabled: true,
},
},
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
assert.NotEmpty(t, created.ID, "service should have an ID")
assert.Contains(t, created.Domain, "test.netbird.io", "domain should use cluster domain")
assert.Equal(t, reverseproxy.SourceEphemeral, created.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, created.SourcePeer, "source peer should be set")
assert.NotNil(t, created.Meta.LastRenewedAt, "last renewed should be set")
assert.NotEmpty(t, resp.ServiceName, "service name should be generated")
assert.Contains(t, resp.Domain, "test.netbird.io", "domain should use cluster domain")
assert.NotEmpty(t, resp.ServiceURL, "service URL should be set")
// Verify service is persisted in store
persisted, err := testStore.GetServiceByID(ctx, store.LockingStrengthNone, testAccountID, created.ID)
persisted, err := testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
assert.Equal(t, created.ID, persisted.ID)
assert.Equal(t, created.Domain, persisted.Domain)
assert.Equal(t, resp.Domain, persisted.Domain)
assert.Equal(t, reverseproxy.SourceEphemeral, persisted.Source, "source should be ephemeral")
assert.Equal(t, testPeerID, persisted.SourcePeer, "source peer should be set")
assert.NotNil(t, persisted.Meta.LastRenewedAt, "last renewed should be set")
})
t.Run("creates service with custom domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
service := &reverseproxy.Service{
Name: "custom",
Domain: "custom.example.com",
Enabled: true,
Targets: []*reverseproxy.Target{
{
AccountID: testAccountID,
Port: 80,
Protocol: "http",
TargetId: testPeerID,
TargetType: reverseproxy.TargetTypePeer,
Enabled: true,
},
},
req := &reverseproxy.ExposeServiceRequest{
Port: 80,
Protocol: "http",
Domain: "example.com",
}
created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
assert.Equal(t, "custom.example.com", created.Domain, "should keep the provided domain")
assert.Contains(t, resp.Domain, "example.com", "should use the provided domain")
})
t.Run("replaces host by peer IP lookup", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
t.Run("validates expose permission internally", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
service := &reverseproxy.Service{
Name: "lookup-test",
Enabled: true,
Targets: []*reverseproxy.Target{
{
AccountID: testAccountID,
Port: 3000,
Protocol: "http",
TargetId: testPeerID,
TargetType: reverseproxy.TargetTypePeer,
Enabled: true,
},
},
// Disable peer expose
s, err := testStore.GetAccountSettings(ctx, store.LockingStrengthNone, testAccountID)
require.NoError(t, err)
s.PeerExposeEnabled = false
err = testStore.SaveAccountSettings(ctx, testAccountID, s)
require.NoError(t, err)
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
created, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, service)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "not enabled")
})
t.Run("validates request fields", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
Port: 0,
Protocol: "http",
}
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.Error(t, err)
assert.Contains(t, err.Error(), "port")
})
}
func TestExposeServiceRequestValidate(t *testing.T) {
tests := []struct {
name string
req reverseproxy.ExposeServiceRequest
wantErr string
}{
{
name: "valid http request",
req: reverseproxy.ExposeServiceRequest{Port: 8080, Protocol: "http"},
wantErr: "",
},
{
name: "valid https request with pin",
req: reverseproxy.ExposeServiceRequest{Port: 443, Protocol: "https", Pin: "123456"},
wantErr: "",
},
{
name: "port zero rejected",
req: reverseproxy.ExposeServiceRequest{Port: 0, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "negative port rejected",
req: reverseproxy.ExposeServiceRequest{Port: -1, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "port above 65535 rejected",
req: reverseproxy.ExposeServiceRequest{Port: 65536, Protocol: "http"},
wantErr: "port must be between 1 and 65535",
},
{
name: "unsupported protocol",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "tcp"},
wantErr: "unsupported protocol",
},
{
name: "invalid pin format",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "abc"},
wantErr: "invalid pin",
},
{
name: "pin too short",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "12345"},
wantErr: "invalid pin",
},
{
name: "valid 6-digit pin",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", Pin: "000000"},
wantErr: "",
},
{
name: "empty user group name",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", UserGroups: []string{"valid", ""}},
wantErr: "user group name cannot be empty",
},
{
name: "invalid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "INVALID"},
wantErr: "invalid name prefix",
},
{
name: "valid name prefix",
req: reverseproxy.ExposeServiceRequest{Port: 80, Protocol: "http", NamePrefix: "my-service"},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.req.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
}
})
}
t.Run("nil receiver", func(t *testing.T) {
var req *reverseproxy.ExposeServiceRequest
err := req.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "request cannot be nil")
})
}
func TestDeleteServiceFromPeer_ByDomain(t *testing.T) {
ctx := context.Background()
t.Run("deletes service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
// First create a service
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
require.Len(t, created.Targets, 1)
assert.Equal(t, "100.64.0.1", created.Targets[0].Host, "host should be resolved to peer IP")
// Delete by domain using unexported method
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, false)
require.NoError(t, err)
// Verify service is deleted
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
t.Run("expire uses correct activity", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.deleteServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain, true)
require.NoError(t, err)
})
}
func TestStopServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("stops service by domain", func(t *testing.T) {
mgr, testStore := setupIntegrationTest(t)
req := &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
}
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, req)
require.NoError(t, err)
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.Error(t, err, "service should be deleted")
})
}
func TestDeleteService_UntracksEphemeralExpose(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
assert.Equal(t, 1, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be tracked after create")
// Look up the service by domain to get its store ID
svc, err := mgr.store.GetServiceByDomain(ctx, testAccountID, resp.Domain)
require.NoError(t, err)
// Delete via the API path (user-initiated)
err = mgr.DeleteService(ctx, testAccountID, testUserID, svc.ID)
require.NoError(t, err)
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "expose should be untracked after API delete")
// A new expose should succeed (not blocked by stale tracking)
_, err = mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 9090,
Protocol: "http",
})
assert.NoError(t, err, "new expose should succeed after API delete cleared tracking")
}
func TestDeleteAllServices_UntracksEphemeralExposes(t *testing.T) {
ctx := context.Background()
mgr, _ := setupIntegrationTest(t)
for i := range 3 {
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080 + i,
Protocol: "http",
})
require.NoError(t, err)
}
assert.Equal(t, 3, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be tracked")
err := mgr.DeleteAllServices(ctx, testAccountID, testUserID)
require.NoError(t, err)
assert.Equal(t, 0, mgr.exposeTracker.CountPeerExposes(testPeerID), "all exposes should be untracked after DeleteAllServices")
}
func TestRenewServiceFromPeer(t *testing.T) {
ctx := context.Background()
t.Run("renews tracked expose", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &reverseproxy.ExposeServiceRequest{
Port: 8080,
Protocol: "http",
})
require.NoError(t, err)
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
require.NoError(t, err)
})
t.Run("fails for untracked domain", func(t *testing.T) {
mgr, _ := setupIntegrationTest(t)
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
require.Error(t, err)
})
}