diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 5ae64e9f1..3e28e1380 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -247,7 +247,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) c.metrics.CountToSyncResponseDuration(time.Since(start)) - c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) }(peer) } @@ -370,7 +373,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) return nil } @@ -778,6 +784,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI }, }, }, + MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) diff --git a/management/internals/controllers/network_map/update_channel/updatechannel_test.go 
b/management/internals/controllers/network_map/update_channel/updatechannel_test.go index afc1e2c32..c73baf81f 100644 --- a/management/internals/controllers/network_map/update_channel/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 0, + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 0, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") @@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 10, + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 10, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } peersUpdater.SendUpdate(context.Background(), peer, update2) timeout := time.After(5 * time.Second) diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go index 33643bcbd..0ffddf8b2 100644 --- a/management/internals/controllers/network_map/update_message.go +++ b/management/internals/controllers/network_map/update_message.go @@ -4,6 +4,19 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// MessageType indicates the type of update message for debouncing strategy +type MessageType int + +const ( + // MessageTypeNetworkMap represents network map updates (peers, 
routes, DNS, firewall) + // These updates can be safely debounced - only the latest state matters + MessageTypeNetworkMap MessageType = iota + // MessageTypeControlConfig represents control/config updates (tokens, peer expiration) + // These updates should not be dropped as they contain time-sensitive information + MessageTypeControlConfig +) + type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + MessageType MessageType } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 98c68ebda..ff9d7ea05 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -404,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. +// It implements a backpressure mechanism that sends the first update immediately, +// then debounces subsequent rapid updates, ensuring only the latest update is sent +// after a quiet period. 
func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) + + // Create a debouncer for this peer connection + debouncer := NewUpdateDebouncer(1000 * time.Millisecond) + defer debouncer.Stop() + for { select { // condition when there are some updates + // todo set the updates channel size to 1 case update, open := <-updates: if s.appMetrics != nil { s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) @@ -419,10 +428,28 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } + log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { - log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) - return err + if debouncer.ProcessUpdate(update) { + // Send immediately (first update or after quiet period) + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } + } + + // Timer expired - quiet period reached, send pending updates if any + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + continue + } + log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String()) + for _, pendingUpdate := range pendingUpdates { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an 
update to peer %s: %v", peerKey.String(), err) + return err + } } // condition when client <-> server connection has been terminated diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index ccb32202f..65e58ad41 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/internals/shared/grpc/update_debouncer.go b/management/internals/shared/grpc/update_debouncer.go new file mode 100644 index 000000000..8af9c2656 --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer.go @@ -0,0 +1,103 @@ +package grpc + +import ( + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" +) + +// UpdateDebouncer implements a backpressure mechanism that: 
+// - forwards the first update right away,
+// - collapses runs of consecutive network map updates into the newest one,
+// - keeps every control/config update queued so none of them is lost,
+// - retains arrival order (a control config stays between the network maps around it),
+// - hands the queued updates over once the debounce timer fires.
+type UpdateDebouncer struct {
+	debounceInterval time.Duration
+	timer            *time.Timer
+	pendingUpdates   []*network_map.UpdateMessage // FIFO of updates held back while debouncing
+	timerC           <-chan time.Time
+}
+
+// NewUpdateDebouncer returns a debouncer that uses interval as its debounce period.
+func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer {
+	debouncer := new(UpdateDebouncer)
+	debouncer.debounceInterval = interval
+	return debouncer
+}
+
+// ProcessUpdate records an incoming update and reports whether the caller must send it right away.
+func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool {
+	if d.timer != nil {
+		// A debounce period is running: hold the update back, keeping arrival order.
+		tail := len(d.pendingUpdates) - 1
+		coalesce := tail >= 0 &&
+			update.MessageType == network_map.MessageTypeNetworkMap &&
+			d.pendingUpdates[tail].MessageType == network_map.MessageTypeNetworkMap
+		if coalesce {
+			// Two network maps in a row: the newer one replaces the queued one.
+			d.pendingUpdates[tail] = update
+		} else {
+			// Control configs, and network maps that follow a control config, are appended.
+			d.pendingUpdates = append(d.pendingUpdates, update)
+		}
+		d.resetTimer()
+		return false
+	}
+
+	// No debounce timer is active: start one and tell the caller to send immediately.
+	d.startTimer()
+	return true
+}
+
+// TimerChannel exposes the timer channel for select statements; it is nil while no timer is active.
+func (d *UpdateDebouncer) TimerChannel() <-chan time.Time {
+	if d.timer != nil {
+		return d.timerC
+	}
+	return nil
+}
+
+// GetPendingUpdates returns and clears all pending updates after timer expiration.
+// Updates are returned in the order they were received, with consecutive network maps
+// already coalesced to only the latest one. A non-empty result re-arms the timer to keep
+// debouncing; an empty result clears the timer so the next update is sent immediately.
+func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage {
+	updates := d.pendingUpdates
+	d.pendingUpdates = nil
+	switch {
+	case len(updates) == 0:
+		// True quiet period: drop the timer to return to immediate mode.
+		d.timer, d.timerC = nil, nil
+	case d.timer != nil:
+		// Updates are still arriving: restart the timer and stay in debounce mode.
+		d.timer.Reset(d.debounceInterval)
+	}
+	return updates
+}
+
+// Stop stops the timer, if one is active, and discards updates that are still pending.
+func (d *UpdateDebouncer) Stop() {
+	if d.timer != nil {
+		d.timer.Stop()
+		d.timer, d.timerC = nil, nil
+	}
+	d.pendingUpdates = nil
+}
+
+func (d *UpdateDebouncer) startTimer() {
+	d.timer = time.NewTimer(d.debounceInterval)
+	d.timerC = d.timer.C
+}
+
+// resetTimer restarts the debounce period. A tick that already fired but was not yet
+// received is drained first so it cannot end the new period early (needed before Go 1.23).
+func (d *UpdateDebouncer) resetTimer() {
+	if !d.timer.Stop() {
+		select {
+		case <-d.timer.C:
+		default:
+		}
+	}
+	d.timer.Reset(d.debounceInterval)
+}
diff --git a/management/internals/shared/grpc/update_debouncer_test.go b/management/internals/shared/grpc/update_debouncer_test.go
new file mode 100644
index 000000000..075994a2d
--- /dev/null
+++ b/management/internals/shared/grpc/update_debouncer_test.go
@@ -0,0 +1,587 @@
+package grpc
+
+import (
+	"testing"
+	"time"
+
+	"github.com/netbirdio/netbird/management/internals/controllers/network_map"
+	"github.com/netbirdio/netbird/shared/management/proto"
+)
+
+func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	shouldSend := debouncer.ProcessUpdate(update)
+
+	if
!shouldSend {
+		t.Error("First update should be sent immediately")
+	}
+
+	if debouncer.TimerChannel() == nil {
+		t.Error("Timer should be started after first update")
+	}
+}
+
+func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update3 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// First update should be sent immediately
+	if !debouncer.ProcessUpdate(update1) {
+		t.Error("First update should be sent immediately")
+	}
+
+	// Rapid subsequent updates should be coalesced
+	if debouncer.ProcessUpdate(update2) {
+		t.Error("Second rapid update should not be sent immediately")
+	}
+
+	if debouncer.ProcessUpdate(update3) {
+		t.Error("Third rapid update should not be sent immediately")
+	}
+
+	// Wait for debounce period; a wrong count is fatal so the index below cannot panic
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		if len(pendingUpdates) != 1 {
+			t.Fatalf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
+		}
+		if pendingUpdates[0] != update3 {
+			t.Error("Should get the last update (update3)")
+		}
+	case <-time.After(100 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) {
+	debouncer := NewUpdateDebouncer(30 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// Send first update
+	
debouncer.ProcessUpdate(update1)
+
+	// Send second update within debounce period
+	debouncer.ProcessUpdate(update2)
+
+	// Wait for timer; a wrong count is fatal so the index below cannot panic
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		if len(pendingUpdates) != 1 {
+			t.Fatalf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
+		}
+		if pendingUpdates[0] != update2 {
+			t.Error("Should get the last update")
+		}
+		if pendingUpdates[0] == update1 {
+			t.Error("Should not get the first update")
+		}
+	case <-time.After(100 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update3 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// Send first update
+	debouncer.ProcessUpdate(update1)
+
+	// Wait a bit, but not the full debounce period
+	time.Sleep(30 * time.Millisecond)
+
+	// Send second update - should reset timer
+	debouncer.ProcessUpdate(update2)
+
+	// Wait a bit more
+	time.Sleep(30 * time.Millisecond)
+
+	// Send third update - should reset timer again
+	debouncer.ProcessUpdate(update3)
+
+	// Now wait for the timer (should fire after last update's reset)
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		if len(pendingUpdates) != 1 {
+			t.Fatalf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
+		}
+		if pendingUpdates[0] != update3 {
+			t.Error("Should get the last update (update3)")
+		}
+		// Timer should be restarted since there was a pending update
+		if debouncer.TimerChannel() == nil {
+			
t.Error("Timer should be restarted after sending pending update")
+		}
+	case <-time.After(150 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) {
+	debouncer := NewUpdateDebouncer(30 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update3 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// First update starts the debounce timer and is reported as "send immediately"
+	debouncer.ProcessUpdate(update1)
+
+	// Second update arrives inside the debounce period and is held back as pending
+	debouncer.ProcessUpdate(update2)
+
+	// Wait for the 30ms timer to expire (100ms safety timeout)
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+
+		if len(pendingUpdates) == 0 {
+			t.Fatal("Should have pending update")
+		}
+
+		// GetPendingUpdates re-armed the timer because updates were pending, so update3 is NOT immediate
+		if debouncer.ProcessUpdate(update3) {
+			t.Error("Update after debounced send should not be sent immediately (timer restarted)")
+		}
+
+		// Wait for the restarted timer and verify update3 is the only pending update
+		select {
+		case <-debouncer.TimerChannel():
+			finalUpdates := debouncer.GetPendingUpdates()
+			if len(finalUpdates) != 1 || finalUpdates[0] != update3 {
+				t.Error("Should get update3 as pending")
+			}
+		case <-time.After(100 * time.Millisecond):
+			t.Error("Timer should have fired for restarted timer")
+		}
+	case <-time.After(100 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_StopCleansUp(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+
+	update := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// Send update
to start timer
+	debouncer.ProcessUpdate(update)
+
+	debouncer.Stop()
+	if debouncer.TimerChannel() != nil {
+		t.Error("Timer channel should be nil after Stop")
+	}
+	debouncer.Stop() // a second Stop must be a safe no-op
+}
+
+func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	// Simulate high-frequency updates
+	var lastUpdate *network_map.UpdateMessage
+	sentImmediately := 0
+	for i := 0; i < 100; i++ {
+		update := &network_map.UpdateMessage{
+			Update: &proto.SyncResponse{
+				NetworkMap: &proto.NetworkMap{
+					Serial: uint64(i),
+				},
+			},
+			MessageType: network_map.MessageTypeNetworkMap,
+		}
+		lastUpdate = update
+		if debouncer.ProcessUpdate(update) {
+			sentImmediately++
+		}
+		time.Sleep(1 * time.Millisecond) // Very rapid updates
+	}
+
+	// Only first update should be sent immediately
+	if sentImmediately != 1 {
+		t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately)
+	}
+
+	// Wait for debounce period; a wrong count is fatal so the index below cannot panic
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		if len(pendingUpdates) != 1 {
+			t.Fatalf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
+		}
+		if pendingUpdates[0] != lastUpdate {
+			t.Error("Should get the very last update")
+		}
+		if pendingUpdates[0].Update.NetworkMap.Serial != 99 {
+			t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
+		}
+	case <-time.After(200 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) {
+	debouncer := NewUpdateDebouncer(30 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// Send first update
+	if !debouncer.ProcessUpdate(update) {
+		t.Error("First update should be sent immediately")
+	}
+
+	// Wait for timer to expire with no additional updates (true quiet period)
+	select {
+	case
<-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		if len(pendingUpdates) != 0 {
+			t.Error("Should have no pending updates")
+		}
+		// After true quiet period, timer should be cleared
+		if debouncer.TimerChannel() != nil {
+			t.Error("Timer should be cleared after quiet period")
+		}
+	case <-time.After(100 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	updates := make([]*network_map.UpdateMessage, 5)
+	for i := range updates {
+		updates[i] = &network_map.UpdateMessage{
+			Update: &proto.SyncResponse{
+				NetworkMap: &proto.NetworkMap{
+					Serial: uint64(i),
+				},
+			},
+			MessageType: network_map.MessageTypeNetworkMap,
+		}
+	}
+
+	// First update sent immediately
+	debouncer.ProcessUpdate(updates[0])
+
+	// Send updates 1, 2, 3, 4 rapidly - only last one should remain pending
+	debouncer.ProcessUpdate(updates[1])
+	debouncer.ProcessUpdate(updates[2])
+	debouncer.ProcessUpdate(updates[3])
+	debouncer.ProcessUpdate(updates[4])
+
+	// Wait for debounce; a wrong count is fatal so the index below cannot panic
+	<-debouncer.TimerChannel()
+	pendingUpdates := debouncer.GetPendingUpdates()
+
+	if len(pendingUpdates) != 1 {
+		t.Fatalf("Should get exactly 1 pending update, got %d", len(pendingUpdates))
+	}
+	if pendingUpdates[0].Update.NetworkMap.Serial != 4 {
+		t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial)
+	}
+}
+
+func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) {
+	debouncer := NewUpdateDebouncer(30 * time.Millisecond)
+	defer debouncer.Stop()
+
+	update1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	update2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+
+	// First update sent immediately
+	if
!debouncer.ProcessUpdate(update1) {
+		t.Error("First update should be sent immediately")
+	}
+
+	// Block until the timer fires with nothing queued (true quiet period)
+	<-debouncer.TimerChannel()
+	pendingUpdates := debouncer.GetPendingUpdates()
+
+	if len(pendingUpdates) != 0 {
+		t.Error("Should have no pending updates during quiet period")
+	}
+
+	// The empty GetPendingUpdates cleared the timer, so the next update is immediate again
+	if !debouncer.ProcessUpdate(update2) {
+		t.Error("Update after true quiet period should be sent immediately")
+	}
+}
+
+func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	// Simulate a steady stream of updates spaced closer than the debounce interval
+	for i := 0; i < 10; i++ {
+		update := &network_map.UpdateMessage{
+			Update: &proto.SyncResponse{
+				NetworkMap: &proto.NetworkMap{
+					Serial: uint64(i),
+				},
+			},
+			MessageType: network_map.MessageTypeNetworkMap,
+		}
+
+		if i == 0 {
+			// First one sent immediately
+			if !debouncer.ProcessUpdate(update) {
+				t.Error("First update should be sent immediately")
+			}
+		} else {
+			// Later updates arrive inside the debounce period and must be held back
+			if debouncer.ProcessUpdate(update) {
+				t.Errorf("Update %d should not be sent immediately", i)
+			}
+		}
+
+		// Sleep less than the 50ms debounce interval so the timer is reset before it fires
+		time.Sleep(20 * time.Millisecond)
+	}
+
+	// Now wait for the final debounce; only the newest network map should be pending
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		if len(pendingUpdates) == 0 {
+			t.Fatal("Should have the last update pending")
+		}
+		if pendingUpdates[0].Update.NetworkMap.Serial != 9 {
+			t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial)
+		}
+	case <-time.After(200 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	netmapUpdate := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	tokenUpdate1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
+		MessageType: network_map.MessageTypeControlConfig,
+	}
+	tokenUpdate2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
+		MessageType: network_map.MessageTypeControlConfig,
+	}
+
+	// First update sent immediately
+	debouncer.ProcessUpdate(netmapUpdate)
+
+	// Send multiple control config updates - they should all be queued
+	debouncer.ProcessUpdate(tokenUpdate1)
+	debouncer.ProcessUpdate(tokenUpdate2)
+
+	// Wait for debounce period
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		// Both control configs must be present; fatal so the indexes below cannot panic
+		if len(pendingUpdates) != 2 {
+			t.Fatalf("Expected 2 control config updates, got %d", len(pendingUpdates))
+		}
+		// Control configs must come back in arrival order
+		if pendingUpdates[0] != tokenUpdate1 {
+			t.Error("First pending update should be tokenUpdate1")
+		}
+		if pendingUpdates[1] != tokenUpdate2 {
+			t.Error("Second pending update should be tokenUpdate2")
+		}
+	case <-time.After(200 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	netmapUpdate1 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	netmapUpdate2 := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	tokenUpdate := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}},
+		MessageType:
network_map.MessageTypeControlConfig,
+	}
+
+	// First update sent immediately
+	debouncer.ProcessUpdate(netmapUpdate1)
+
+	// Send token update and network map update
+	debouncer.ProcessUpdate(tokenUpdate)
+	debouncer.ProcessUpdate(netmapUpdate2)
+
+	// Wait for debounce period
+	select {
+	case <-debouncer.TimerChannel():
+		pendingUpdates := debouncer.GetPendingUpdates()
+		// Expect token then network map; fatal so the indexes below cannot panic
+		if len(pendingUpdates) != 2 {
+			t.Fatalf("Expected 2 pending updates, got %d", len(pendingUpdates))
+		}
+		// Token update should come first (preserves order)
+		if pendingUpdates[0] != tokenUpdate {
+			t.Error("First pending update should be tokenUpdate")
+		}
+		// Network map update should come second
+		if pendingUpdates[1] != netmapUpdate2 {
+			t.Error("Second pending update should be netmapUpdate2")
+		}
+	case <-time.After(200 * time.Millisecond):
+		t.Error("Timer should have fired")
+	}
+}
+
+func TestUpdateDebouncer_OrderPreservation(t *testing.T) {
+	debouncer := NewUpdateDebouncer(50 * time.Millisecond)
+	defer debouncer.Stop()
+
+	// Simulate: 50 network maps -> 1 control config -> 50 network maps
+	// Expected result: 3 messages (netmap, controlConfig, netmap)
+
+	// Send first network map immediately
+	firstNetmap := &network_map.UpdateMessage{
+		Update:      &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}},
+		MessageType: network_map.MessageTypeNetworkMap,
+	}
+	if !debouncer.ProcessUpdate(firstNetmap) {
+		t.Error("First update should be sent immediately")
+	}
+
+	// Send 49 more network maps (will be coalesced to last one)
+	var lastNetmapBatch1 *network_map.UpdateMessage
+	for i := 1; i < 50; i++ {
+		lastNetmapBatch1 = &network_map.UpdateMessage{
+			Update:      &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}},
+			MessageType: network_map.MessageTypeNetworkMap,
+		}
+		debouncer.ProcessUpdate(lastNetmapBatch1)
+	}
+
+	// Send 1 control config
+	controlConfig := &network_map.UpdateMessage{
+		Update:
&proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + debouncer.ProcessUpdate(controlConfig) + + // Send 50 more network maps (will be coalesced to last one) + var lastNetmapBatch2 *network_map.UpdateMessage + for i := 50; i < 100; i++ { + lastNetmapBatch2 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch2) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get exactly 3 updates: netmap, controlConfig, netmap + if len(pendingUpdates) != 3 { + t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates)) + } + // First should be the last netmap from batch 1 + if pendingUpdates[0] != lastNetmapBatch1 { + t.Error("First pending update should be last netmap from batch 1") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 49 { + t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + // Second should be the control config + if pendingUpdates[1] != controlConfig { + t.Error("Second pending update should be control config") + } + // Third should be the last netmap from batch 2 + if pendingUpdates[2] != lastNetmapBatch2 { + t.Error("Third pending update should be last netmap from batch 2") + } + if pendingUpdates[2].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} diff --git a/management/server/management_test.go b/management/server/management_test.go index 0864baadf..de02855bf 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) { initialPeers := 10 
additionalPeers := 10 + expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself var peers []wgtypes.Key for i := 0; i < initialPeers; i++ { @@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) { peers = append(peers, key) } + // Track the maximum peer count each peer has seen + type peerState struct { + mu sync.Mutex + maxPeerCount int + done bool + } + peerStates := make(map[string]*peerState) + for _, pk := range peers { + peerStates[pk.PublicKey().String()] = &peerState{} + } + var wg sync.WaitGroup - wg.Add(initialPeers + initialPeers*additionalPeers) + wg.Add(initialPeers) // One completion per initial peer var syncClients []mgmtProto.ManagementService_SyncClient for _, pk := range peers { @@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) { syncClients = append(syncClients, s) go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) { + pubKey := pk.PublicKey().String() + state := peerStates[pubKey] + for { encMsg := &mgmtProto.EncryptedMessage{} err := syncStream.RecvMsg(encMsg) @@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) { } decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk) if decErr != nil { - t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr) + t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr) return } resp := &mgmtProto.SyncResponse{} umErr := pb.Unmarshal(decryptedBytes, resp) if umErr != nil { - t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr) + t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr) return } - // We only count if there's a new peer update - if len(resp.GetRemotePeers()) > 0 { + + // Track the maximum peer count seen (due to debouncing, updates are coalesced) + peerCount := len(resp.GetRemotePeers()) + state.mu.Lock() + if peerCount > state.maxPeerCount { + state.maxPeerCount 
= peerCount + } + // Signal completion when this peer has seen all expected peers + if !state.done && state.maxPeerCount >= expectedPeerCount { + state.done = true wg.Done() } + state.mu.Unlock() } }(pk, s) } @@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) { time.Sleep(time.Duration(n) * time.Millisecond) } - wg.Wait() + // Wait for debouncer to flush final updates (debounce interval is 1000ms) + time.Sleep(1500 * time.Millisecond) + + // Wait with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success - all peers received expected peer count + case <-time.After(5 * time.Second): + // Timeout - report which peers didn't receive all updates + t.Error("Timeout waiting for all peers to receive updates") + for pubKey, state := range peerStates { + state.mu.Lock() + if state.maxPeerCount < expectedPeerCount { + t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount) + } + state.mu.Unlock() + } + } for _, sc := range syncClients { err := sc.CloseSend()