mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management] Replace in-memory expose tracker with SQL-backed operations (#5494)
The expose tracker used sync.Map for in-memory TTL tracking of active expose sessions, which broke and lost all sessions on restart. Replace with SQL-backed operations that reuse the existing meta_last_renewed_at column: - Add store methods: RenewEphemeralService, GetExpiredEphemeralServices, CountEphemeralServicesByPeer, EphemeralServiceExists - Move duplicate/limit checks inside a transaction with row-level locking (SELECT ... FOR UPDATE) to prevent concurrent bypass - Reaper re-checks expiry under row lock to avoid deleting a just-renewed service and prevent duplicate event emission - Add composite index on (source, source_peer) for efficient queries - Batch-limit and column-select the reaper query to avoid DB/GC spikes - Filter out malformed rows with empty source_peer
This commit is contained in:
@@ -10,184 +10,62 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
func TestExposeKey(t *testing.T) {
|
||||
assert.Equal(t, "peer1:example.com", exposeKey("peer1", "example.com"))
|
||||
assert.Equal(t, "peer2:other.com", exposeKey("peer2", "other.com"))
|
||||
assert.NotEqual(t, exposeKey("peer1", "a.com"), exposeKey("peer1", "b.com"))
|
||||
}
|
||||
|
||||
func TestTrackExposeIfAllowed(t *testing.T) {
|
||||
t.Run("first track succeeds", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.False(t, alreadyTracked, "first track should not be duplicate")
|
||||
assert.True(t, ok, "first track should be allowed")
|
||||
})
|
||||
|
||||
t.Run("duplicate track detected", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.True(t, alreadyTracked, "second track should be duplicate")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("rejects when at limit", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
for i := range maxExposesPerPeer {
|
||||
_, ok := tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
|
||||
assert.True(t, ok, "track %d should be allowed", i)
|
||||
}
|
||||
|
||||
alreadyTracked, ok := tracker.TrackExposeIfAllowed("peer1", "over-limit.com", "acct1")
|
||||
assert.False(t, alreadyTracked)
|
||||
assert.False(t, ok, "should reject when at limit")
|
||||
})
|
||||
|
||||
t.Run("other peer unaffected by limit", func(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
for i := range maxExposesPerPeer {
|
||||
tracker.TrackExposeIfAllowed("peer1", "domain-"+string(rune('a'+i))+".com", "acct1")
|
||||
}
|
||||
|
||||
_, ok := tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
|
||||
assert.True(t, ok, "other peer should still be within limit")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUntrackExpose(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
assert.Equal(t, 1, tracker.CountPeerExposes("peer1"))
|
||||
|
||||
tracker.UntrackExpose("peer1", "a.com")
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
|
||||
}
|
||||
|
||||
func TestCountPeerExposes(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer1"))
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
tracker.TrackExposeIfAllowed("peer1", "b.com", "acct1")
|
||||
tracker.TrackExposeIfAllowed("peer2", "a.com", "acct1")
|
||||
|
||||
assert.Equal(t, 2, tracker.CountPeerExposes("peer1"), "peer1 should have 2 exposes")
|
||||
assert.Equal(t, 1, tracker.CountPeerExposes("peer2"), "peer2 should have 1 expose")
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes("peer3"), "peer3 should have 0 exposes")
|
||||
}
|
||||
|
||||
func TestMaxExposesPerPeer(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
assert.Equal(t, maxExposesPerPeer, tracker.MaxExposesPerPeer())
|
||||
}
|
||||
|
||||
func TestRenewTrackedExpose(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
|
||||
found := tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.False(t, found, "should not find untracked expose")
|
||||
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
found = tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.True(t, found, "should find tracked expose")
|
||||
}
|
||||
|
||||
func TestRenewTrackedExpose_RejectsExpiring(t *testing.T) {
|
||||
tracker := &exposeTracker{}
|
||||
tracker.TrackExposeIfAllowed("peer1", "a.com", "acct1")
|
||||
|
||||
// Simulate reaper marking the expose as expiring
|
||||
key := exposeKey("peer1", "a.com")
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.expiring = true
|
||||
expose.mu.Unlock()
|
||||
|
||||
found := tracker.RenewTrackedExpose("peer1", "a.com")
|
||||
assert.False(t, found, "should reject renewal when expiring")
|
||||
}
|
||||
|
||||
func TestReapExpiredExposes(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Manually expire the tracked entry
|
||||
key := exposeKey(testPeerID, resp.Domain)
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
// Manually expire the service by backdating meta_last_renewed_at
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Add an active (non-expired) tracking entry
|
||||
tracker.activeExposes.Store(exposeKey("peer1", "active.com"), &trackedExpose{
|
||||
domain: "active.com",
|
||||
accountID: testAccountID,
|
||||
peerID: "peer1",
|
||||
lastRenewed: time.Now(),
|
||||
// Create a non-expired service
|
||||
resp2, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8081,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
tracker.reapExpiredExposes()
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
_, exists := tracker.activeExposes.Load(key)
|
||||
assert.False(t, exists, "expired expose should be removed")
|
||||
// Expired service should be deleted
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
require.Error(t, err, "expired service should be deleted")
|
||||
|
||||
_, exists = tracker.activeExposes.Load(exposeKey("peer1", "active.com"))
|
||||
assert.True(t, exists, "active expose should remain")
|
||||
// Non-expired service should remain
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp2.Domain)
|
||||
require.NoError(t, err, "active service should remain")
|
||||
}
|
||||
|
||||
func TestReapExpiredExposes_SetsExpiringFlag(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
|
||||
func TestReapAlreadyDeletedService(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
key := exposeKey(testPeerID, resp.Domain)
|
||||
val, _ := tracker.activeExposes.Load(key)
|
||||
expose := val.(*trackedExpose)
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Expire it
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
// Delete the service before reaping
|
||||
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Renew should succeed before reaping
|
||||
assert.True(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should succeed before reaper runs")
|
||||
|
||||
// Re-expire and reap
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
|
||||
tracker.reapExpiredExposes()
|
||||
|
||||
// Entry is deleted, renew returns false
|
||||
assert.False(t, tracker.RenewTrackedExpose(testPeerID, resp.Domain), "renew should fail after reap")
|
||||
// Reaping should handle the already-deleted service gracefully
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
}
|
||||
|
||||
func TestConcurrentTrackAndCount(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
tracker := mgr.exposeTracker
|
||||
func TestConcurrentReapAndRenew(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range 5 {
|
||||
@@ -198,59 +76,133 @@ func TestConcurrentTrackAndCount(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Manually expire all tracked entries
|
||||
tracker.activeExposes.Range(func(_, val any) bool {
|
||||
expose := val.(*trackedExpose)
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now().Add(-2 * exposeTTL)
|
||||
expose.mu.Unlock()
|
||||
return true
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.reapExpiredExposes()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
tracker.CountPeerExposes(testPeerID)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
assert.Equal(t, 0, tracker.CountPeerExposes(testPeerID), "all expired exposes should be reaped")
|
||||
}
|
||||
|
||||
func TestTrackedExposeMutexProtectsLastRenewed(t *testing.T) {
|
||||
expose := &trackedExpose{
|
||||
lastRenewed: time.Now().Add(-1 * time.Hour),
|
||||
// Expire all services
|
||||
services, err := testStore.GetAccountServices(ctx, store.LockingStrengthNone, testAccountID)
|
||||
require.NoError(t, err)
|
||||
for _, svc := range services {
|
||||
if svc.Source == rpservice.SourceEphemeral {
|
||||
expireEphemeralService(t, testStore, testAccountID, svc.Domain)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
expose.mu.Lock()
|
||||
expose.lastRenewed = time.Now()
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for range 100 {
|
||||
expose.mu.Lock()
|
||||
_ = time.Since(expose.lastRenewed)
|
||||
expose.mu.Unlock()
|
||||
}
|
||||
_, _ = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
expose.mu.Lock()
|
||||
require.False(t, expose.lastRenewed.IsZero(), "lastRenewed should not be zero after concurrent access")
|
||||
expose.mu.Unlock()
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count, "all expired services should be reaped")
|
||||
}
|
||||
|
||||
func TestRenewEphemeralService(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("renew succeeds for active service", func(t *testing.T) {
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8082,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("renew fails for nonexistent domain", func(t *testing.T) {
|
||||
err := mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, "nonexistent.com")
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no active expose session")
|
||||
})
|
||||
}
|
||||
|
||||
func TestCountAndExistsEphemeralServices(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
count, err := mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(0), count)
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8083,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
count, err = mgr.store.CountEphemeralServicesByPeer(ctx, store.LockingStrengthNone, testAccountID, testPeerID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
exists, err := mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, exists, "service should exist")
|
||||
|
||||
exists, err = mgr.store.EphemeralServiceExists(ctx, store.LockingStrengthNone, testAccountID, testPeerID, "no-such.domain")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "non-existent service should not exist")
|
||||
}
|
||||
|
||||
func TestMaxExposesPerPeerEnforced(t *testing.T) {
|
||||
mgr, _ := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := range maxExposesPerPeer {
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8090 + i,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err, "expose %d should succeed", i)
|
||||
}
|
||||
|
||||
_, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 9999,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "maximum number of active expose sessions")
|
||||
}
|
||||
|
||||
func TestReapSkipsRenewedService(t *testing.T) {
|
||||
mgr, testStore := setupIntegrationTest(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
|
||||
Port: 8086,
|
||||
Protocol: "http",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expire the service
|
||||
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
|
||||
|
||||
// Renew it before the reaper runs
|
||||
err = mgr.RenewServiceFromPeer(ctx, testAccountID, testPeerID, resp.Domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reaper should skip it because the re-check sees a fresh timestamp
|
||||
mgr.exposeReaper.reapExpiredExposes(ctx)
|
||||
|
||||
_, err = testStore.GetServiceByDomain(ctx, testAccountID, resp.Domain)
|
||||
require.NoError(t, err, "renewed service should survive reaping")
|
||||
}
|
||||
|
||||
// expireEphemeralService backdates meta_last_renewed_at to force expiration.
|
||||
func expireEphemeralService(t *testing.T, s store.Store, accountID, domain string) {
|
||||
t.Helper()
|
||||
svc, err := s.GetServiceByDomain(context.Background(), accountID, domain)
|
||||
require.NoError(t, err)
|
||||
|
||||
expired := time.Now().Add(-2 * exposeTTL)
|
||||
svc.Meta.LastRenewedAt = &expired
|
||||
err = s.UpdateService(context.Background(), svc)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user