[proxy] consolidate mapping update (#6072)

This commit is contained in:
Pascal Fischer
2026-05-05 18:40:42 +02:00
committed by GitHub
parent 3c28d29725
commit cfb1b3fe31
6 changed files with 559 additions and 58 deletions

View File

@@ -364,14 +364,16 @@ func TestIntegration_ProxyConnection_HappyPath(t *testing.T) {
})
require.NoError(t, err)
// Receive all mappings from the snapshot - server sends each mapping individually
mappingsByID := make(map[string]*proto.ProxyMapping)
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
for _, m := range msg.GetMapping() {
mappingsByID[m.GetId()] = m
}
if msg.GetInitialSyncComplete() {
break
}
}
// Should receive 2 mappings total
@@ -411,12 +413,14 @@ func TestIntegration_ProxyConnection_SendsClusterAddress(t *testing.T) {
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
mappings := make([]*proto.ProxyMapping, 0)
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
// Should receive the 2 mappings matching the cluster
@@ -440,13 +444,15 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
clusterAddress := "test.proxy.io"
proxyID := "test-proxy-reconnect"
// Helper to receive all mappings from a stream
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient, count int) []*proto.ProxyMapping {
receiveMappings := func(stream proto.ProxyService_GetMappingUpdateClient) []*proto.ProxyMapping {
var mappings []*proto.ProxyMapping
for i := 0; i < count; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
mappings = append(mappings, msg.GetMapping()...)
if msg.GetInitialSyncComplete() {
break
}
}
return mappings
}
@@ -460,7 +466,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
})
require.NoError(t, err)
firstMappings := receiveMappings(stream1, 2)
firstMappings := receiveMappings(stream1)
cancel1()
time.Sleep(100 * time.Millisecond)
@@ -476,7 +482,7 @@ func TestIntegration_ProxyConnection_Reconnect_ReceivesSameConfig(t *testing.T)
})
require.NoError(t, err)
secondMappings := receiveMappings(stream2, 2)
secondMappings := receiveMappings(stream2)
// Should receive the same mappings
assert.Equal(t, len(firstMappings), len(secondMappings),
@@ -542,12 +548,14 @@ func TestIntegration_ProxyConnection_ReconnectDoesNotDuplicateState(t *testing.T
}
}
// Helper to receive and apply all mappings
receiveAndApply := func(stream proto.ProxyService_GetMappingUpdateClient) {
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
applyMappings(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
}
}
@@ -636,12 +644,14 @@ func TestIntegration_ProxyConnection_MultipleProxiesReceiveUpdates(t *testing.T)
})
require.NoError(t, err)
// Receive all mappings - server sends each mapping individually
count := 0
for i := 0; i < 2; i++ {
for {
msg, err := stream.Recv()
require.NoError(t, err)
count += len(msg.GetMapping())
if msg.GetInitialSyncComplete() {
break
}
}
mu.Lock()
@@ -681,9 +691,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream1.Recv()
for {
msg, err := stream1.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
require.Contains(t, setup.proxyService.GetConnectedProxies(), proxyID,
@@ -699,9 +712,12 @@ func TestIntegration_ProxyConnection_FastReconnectDoesNotLoseState(t *testing.T)
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err := stream2.Recv()
for {
msg, err := stream2.Recv()
require.NoError(t, err)
if msg.GetInitialSyncComplete() {
break
}
}
cancel1()

View File

@@ -943,6 +943,8 @@ func (s *Server) newManagementMappingWorker(ctx context.Context, client proto.Pr
operation := func() error {
s.Logger.Debug("connecting to management mapping stream")
initialSyncDone = false
if s.healthChecker != nil {
s.healthChecker.SetManagementConnected(false)
}
@@ -1000,6 +1002,11 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
return ctx.Err()
}
var snapshotIDs map[types.ServiceID]struct{}
if !*initialSyncDone {
snapshotIDs = make(map[types.ServiceID]struct{})
}
for {
// Check for context completion to gracefully shutdown.
select {
@@ -1020,17 +1027,45 @@ func (s *Server) handleMappingStream(ctx context.Context, mappingClient proto.Pr
s.processMappings(ctx, msg.GetMapping())
s.Logger.Debug("Processing mapping update completed")
if !*initialSyncDone && msg.GetInitialSyncComplete() {
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
if !*initialSyncDone {
for _, m := range msg.GetMapping() {
snapshotIDs[types.ServiceID(m.GetId())] = struct{}{}
}
if msg.GetInitialSyncComplete() {
s.reconcileSnapshot(ctx, snapshotIDs)
snapshotIDs = nil
if s.healthChecker != nil {
s.healthChecker.SetInitialSyncComplete()
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
*initialSyncDone = true
s.Logger.Info("Initial mapping sync complete")
}
}
}
}
// reconcileSnapshot removes local mappings that are absent from the snapshot.
// This ensures services deleted while the proxy was disconnected get cleaned up.
func (s *Server) reconcileSnapshot(ctx context.Context, snapshotIDs map[types.ServiceID]struct{}) {
s.portMu.RLock()
var stale []*proto.ProxyMapping
for svcID, mapping := range s.lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, mapping)
}
}
s.portMu.RUnlock()
for _, mapping := range stale {
s.Logger.WithFields(log.Fields{
"service_id": mapping.GetId(),
"domain": mapping.GetDomain(),
}).Info("Removing stale mapping absent from snapshot")
s.removeMapping(ctx, mapping)
}
}
func (s *Server) processMappings(ctx context.Context, mappings []*proto.ProxyMapping) {
for _, mapping := range mappings {
s.Logger.WithFields(log.Fields{

View File

@@ -0,0 +1,227 @@
package proxy
import (
"context"
"io"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/internal/health"
"github.com/netbirdio/netbird/proxy/internal/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
// collectStaleIDs mirrors the stale-detection logic in reconcileSnapshot
// so we can verify it without triggering removeMapping (which requires full
// server wiring). This keeps the test focused on the detection algorithm.
func collectStaleIDs(lastMappings map[types.ServiceID]*proto.ProxyMapping, snapshotIDs map[types.ServiceID]struct{}) []types.ServiceID {
var stale []types.ServiceID
for svcID := range lastMappings {
if _, ok := snapshotIDs[svcID]; !ok {
stale = append(stale, svcID)
}
}
return stale
}
// TestStaleDetection_PartialOverlap verifies that only services absent from
// the snapshot are flagged as stale.
func TestStaleDetection_PartialOverlap(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
"svc-stale-a": {Id: "svc-stale-a"},
"svc-stale-b": {Id: "svc-stale-b"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
"svc-3": {}, // new service, not in local
}
stale := collectStaleIDs(local, snapshot)
assert.Len(t, stale, 2)
staleSet := make(map[types.ServiceID]struct{})
for _, id := range stale {
staleSet[id] = struct{}{}
}
assert.Contains(t, staleSet, types.ServiceID("svc-stale-a"))
assert.Contains(t, staleSet, types.ServiceID("svc-stale-b"))
}
// TestStaleDetection_AllStale verifies an empty snapshot flags everything.
func TestStaleDetection_AllStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
stale := collectStaleIDs(local, map[types.ServiceID]struct{}{})
assert.Len(t, stale, 2)
}
// TestStaleDetection_NoneStale verifies full overlap produces no stale entries.
func TestStaleDetection_NoneStale(t *testing.T) {
local := map[types.ServiceID]*proto.ProxyMapping{
"svc-1": {Id: "svc-1"},
"svc-2": {Id: "svc-2"},
}
snapshot := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
stale := collectStaleIDs(local, snapshot)
assert.Empty(t, stale)
}
// TestStaleDetection_EmptyLocal verifies no stale entries when local is empty.
func TestStaleDetection_EmptyLocal(t *testing.T) {
stale := collectStaleIDs(
map[types.ServiceID]*proto.ProxyMapping{},
map[types.ServiceID]struct{}{"svc-1": {}},
)
assert.Empty(t, stale)
}
// TestReconcileSnapshot_NoStale verifies reconciliation is a no-op when all
// local mappings are present in the snapshot (removeMapping is never called).
func TestReconcileSnapshot_NoStale(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2"}
snapshotIDs := map[types.ServiceID]struct{}{
"svc-1": {},
"svc-2": {},
}
// This should not panic — no stale entries means removeMapping is never called.
s.reconcileSnapshot(context.Background(), snapshotIDs)
assert.Len(t, s.lastMappings, 2, "no mappings should be removed when all are in snapshot")
}
// TestReconcileSnapshot_EmptyLocal verifies reconciliation is a no-op with
// no local mappings.
func TestReconcileSnapshot_EmptyLocal(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.reconcileSnapshot(context.Background(), map[types.ServiceID]struct{}{"svc-1": {}})
assert.Empty(t, s.lastMappings)
}
// --- handleMappingStream tests for batched snapshot ID accumulation ---
// TestHandleMappingStream_BatchedSnapshotSyncComplete verifies that sync is
// marked done only after the final InitialSyncComplete message, even when
// the snapshot arrives in multiple batches.
func TestHandleMappingStream_BatchedSnapshotSyncComplete(t *testing.T) {
checker := health.NewChecker(nil, nil)
s := &Server{
Logger: log.StandardLogger(),
healthChecker: checker,
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // batch 1: no sync-complete
{}, // batch 2: no sync-complete
{InitialSyncComplete: true}, // batch 3: sync done
},
}
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.True(t, syncDone, "sync should be marked done after final batch")
}
// TestHandleMappingStream_PostSyncDoesNotReconcile verifies that messages
// arriving after InitialSyncComplete do not trigger a second reconciliation.
func TestHandleMappingStream_PostSyncDoesNotReconcile(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
// Simulate state left over from a previous sync.
s.lastMappings["svc-1"] = &proto.ProxyMapping{Id: "svc-1", AccountId: "acct-1"}
s.lastMappings["svc-2"] = &proto.ProxyMapping{Id: "svc-2", AccountId: "acct-1"}
stream := &mockMappingStream{
messages: []*proto.GetMappingUpdateResponse{
{}, // post-sync empty message — must not reconcile
},
}
syncDone := true // sync already completed in a previous stream
err := s.handleMappingStream(context.Background(), stream, &syncDone)
require.NoError(t, err)
assert.Len(t, s.lastMappings, 2,
"post-sync messages must not trigger reconciliation — all entries should survive")
}
// TestHandleMappingStream_ImmediateEOF_NoReconciliation verifies that if the
// stream closes before sync completes, no reconciliation occurs.
func TestHandleMappingStream_ImmediateEOF_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
stream := &mockMappingStream{} // no messages → immediate EOF
syncDone := false
err := s.handleMappingStream(context.Background(), stream, &syncDone)
assert.NoError(t, err)
assert.False(t, syncDone, "sync should not be marked done on immediate EOF")
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync never completed")
}
// mockErrRecvStream returns an error on the second Recv to verify
// handleMappingStream returns without completing sync.
type mockErrRecvStream struct {
mockMappingStream
calls int
}
func (m *mockErrRecvStream) Recv() (*proto.GetMappingUpdateResponse, error) {
m.calls++
if m.calls == 1 {
return &proto.GetMappingUpdateResponse{}, nil
}
return nil, io.ErrUnexpectedEOF
}
func TestHandleMappingStream_ErrorMidSync_NoReconciliation(t *testing.T) {
s := &Server{
Logger: log.StandardLogger(),
routerReady: closedChan(),
lastMappings: make(map[types.ServiceID]*proto.ProxyMapping),
}
s.lastMappings["svc-stale"] = &proto.ProxyMapping{Id: "svc-stale", AccountId: "acct-1"}
syncDone := false
err := s.handleMappingStream(context.Background(), &mockErrRecvStream{}, &syncDone)
assert.Error(t, err)
assert.False(t, syncDone)
_, hasStale := s.lastMappings["svc-stale"]
assert.True(t, hasStale, "stale mapping should remain when sync was interrupted by error")
}