Compare commits

..

11 Commits

Author SHA1 Message Date
riccardom
4988b6726e Aligns new tests to signature 2026-06-28 17:25:17 +02:00
riccardom
2552830184 Prevents skipping of intermediate map updates potentially not applied
by moving the persistence from applySync to the map state manager
2026-06-28 17:23:34 +02:00
riccardom
3b8fc688f4 Do the wholesale (firewall/routes/dns) once only 2026-06-28 17:23:34 +02:00
riccardom
d82d62e818 Adds explicit merge call for future map updates 2026-06-28 17:20:00 +02:00
riccardom
0bf964dad7 Do not process intermediate one if new ones are fresher just use the freshest 2026-06-28 17:20:00 +02:00
riccardom
297dcb3e24 Always run onConverged for every map that is processed 2026-06-28 17:20:00 +02:00
riccardom
bc22926fe0 Drop in case of error, will reconcile with next update 2026-06-28 17:20:00 +02:00
riccardom
d3f2ef9adb Comment why not serial 2026-06-28 17:20:00 +02:00
riccardom
5bec1e8f03 Adds map state manager 2026-06-28 17:20:00 +02:00
riccardom
74bb5c613e Allows to specify max batch for tests 2026-06-28 17:20:00 +02:00
riccardom
29dde908ae Modifies handleSync to support progressive peers conns convergence 2026-06-28 17:19:27 +02:00
6 changed files with 652 additions and 137 deletions

View File

@@ -210,6 +210,12 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
// forwardingRules holds the ingress forward rules applied for the current target.
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
// it is stashed here so the final, peer-converged pass can build the lazy-connection
// exclude list without recomputing them on every bounded peer pass.
forwardingRules []firewallManager.ForwardRule
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
@@ -762,7 +768,15 @@ func (e *Engine) blockLanAccess() {
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
// batch at a time and other subsystems can interleave between passes. It is
// passed in (not read globally) so tests can exercise the multi-pass path.
const maxPeersPerSyncPass = 300
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
// when more changed peers remained than the cap, so the caller re-runs.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
// first, check if peers have been modified
var modified []*mgmProto.RemotePeerConfig
@@ -792,26 +806,32 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
}
}
more := false
if len(modified) > maxBatch {
modified = modified[:maxBatch]
more = true
}
// second, close all modified connections and remove them from the state map
for _, p := range modified {
err := e.removePeer(p.GetWgPubKey())
if err != nil {
return err
if err := e.removePeer(p.GetWgPubKey()); err != nil {
return false, err
}
}
// third, add the peer connections again
for _, p := range modified {
err := e.addNewPeer(p)
if err != nil {
return err
if err := e.addNewPeer(p); err != nil {
return false, err
}
}
return nil
return more, nil
}
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers removes up to maxBatch peers per call. It returns true when more
// peers remained to remove than the cap, so the caller re-runs.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey())
@@ -819,14 +839,19 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
more := false
if len(toRemove) > maxBatch {
toRemove = toRemove[:maxBatch]
more = true
}
for _, p := range toRemove {
err := e.removePeer(p)
if err != nil {
return err
if err := e.removePeer(p); err != nil {
return false, err
}
log.Infof("removed peer %s", p)
}
return nil
return more, nil
}
func (e *Engine) removeAllPeers() error {
@@ -895,19 +920,17 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
duration := time.Since(started)
log.Infof("sync finished in %s", duration)
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
}()
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
// returns true if more peers remained than the per-pass cap. It is driven by the
// mapStateManager, which re-invokes it (releasing the lock between passes) until
// the update is fully applied.
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return e.ctx.Err()
return false, e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
@@ -915,7 +938,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return err
return false, err
}
// Posture checks are bound to the network map presence:
@@ -925,23 +948,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
return false, nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return err
return false, err
}
e.persistSyncResponse(update)
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
return err
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
if err != nil {
return false, err
}
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
return nil
return more, nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
@@ -987,6 +1009,13 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
// Only persist updates that carry a network map. Config-only updates (e.g. relay
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
// the last full map on disk and break restore-on-restart.
if update.GetNetworkMap() == nil {
return
}
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
@@ -1278,7 +1307,19 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
// The map-state manager converges the latest update in the background in
// bounded passes; the stream callback only hands it the newest target.
manager := newMapStateManager(e.applySyncPass, e.persistSyncResponse, func(d time.Duration) {
log.Infof("sync finished in %s", d)
e.clientMetrics.RecordSyncDuration(e.ctx, d)
})
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
manager.run(e.ctx)
}()
err = e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1329,21 +1370,104 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
return nil
}
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
// and up to maxBatch peers per phase. It returns true when more peers remained
// than the cap, so the caller re-runs until convergence.
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
if err != nil {
return err
return false, err
}
}
serial := networkMap.GetSerial()
if e.networkSerial > serial {
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
return nil
return false, nil
}
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
// up-front and only once per target: they are cheap, local, idempotent and must
// be in place before peers come up (fail-closed). On the bounded re-runs that only
// drain the remaining peer batches they are skipped — the applied forward rules are
// reused from e.forwardingRules for the lazy-exclude finalize.
if firstPass {
e.applyWholesale(networkMap, serial)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
needMore := false
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return false, err
}
} else {
removeMore, err := e.removePeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
addMore, err := e.addNewPeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
needMore = removeMore || modifyMore || addMore
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// Set the exclude list only once peers have fully converged (this pass added
// the last batch). It needs all target peers present in the store, and
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
if !needMore {
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
}
e.networkSerial = serial
return needMore, nil
}
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
// ingress forward rules — that must be in place before peers come up. It runs once
// per target (first pass only); the resulting forward rules are stashed in
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
@@ -1404,61 +1528,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return err
}
} else {
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
e.networkSerial = serial
return nil
e.forwardingRules = forwardingRules
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
@@ -1638,14 +1708,23 @@ func addrToString(addr netip.Addr) string {
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
// when more new peers remained than the cap, so the caller re-runs.
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
added := 0
for _, p := range peersUpdate {
err := e.addNewPeer(p)
if err != nil {
return err
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
continue // already present (cheap skip), does not count toward the cap
}
if added >= maxBatch {
return true, nil // at least one more new peer remains
}
if err := e.addNewPeer(p); err != nil {
return false, err
}
added++
}
return nil
return false, nil
}
// addNewPeer add peer if connection doesn't exist

View File

@@ -124,7 +124,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
@@ -146,7 +146,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
@@ -159,7 +159,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)

View File

@@ -433,7 +433,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
t.Run(c.name, func(t *testing.T) {
err = engine.updateNetworkMap(c.networkMap)
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
if err != nil {
t.Fatal(err)
return
@@ -460,6 +460,47 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
})
}
// chunked apply: with a per-pass cap smaller than the number of peers, a
// single updateNetworkMap applies one batch and reports more==true; the
// caller re-runs until convergence. (engine currently holds 0 peers.)
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 6,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.True(t, more, "pass 2 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 2)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 3 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 3)
})
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 7,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 2 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
})
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
@@ -630,7 +671,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
@@ -834,7 +875,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
}()
err = engine.updateNetworkMap(testCase.networkMap)
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")

190
client/internal/mapsync.go Normal file
View File

@@ -0,0 +1,190 @@
package internal
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// mapStateManager is the single read/write point between the management stream
// (writes) and the convergence loop (reads/applies).
//
// The stream calls SetTarget with the latest full SyncResponse — the complete
// desired state. A single background goroutine (run) applies it to the engine in
// bounded passes via apply() until converged, releasing syncMsgMux between passes
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
// coalesces: it keeps converging toward the latest target and the intermediate one
// is SKIPPED — never applied on its own (logged, no onConverged).
//
// Convergence is a single comparison: appliedGen == targetGen. targetGen
// increments on every SetTarget (an internal generation counter, so it also covers
// config-only updates that carry no network-map serial).
//
// onConverged fires once for each — and only each — map that is actually processed
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
// to a real, completed alignment.
type mapStateManager struct {
// apply performs one bounded apply pass and reports whether more passes are needed.
// firstPass is true on the first pass of a given target, so the caller can run
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
// re-runs that only drain the bounded peer batches. The manager owns this signal
// because it owns the convergence boundary; the engine need not track serials for it.
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
// onConverged is called once per processed map, with the elapsed time since that
// map was received (for the sync-duration metric / "sync finished" log).
onConverged func(time.Duration)
// persist snapshots an update to disk for restore-on-restart. Called once per
// update received from management (in SetTarget), including ones later coalesced
// or skipped from apply, so the on-disk state mirrors what management last sent.
// The impl skips config-only updates (nil NetworkMap). May be nil.
persist func(*mgmProto.SyncResponse)
mu sync.Mutex
target *mgmProto.SyncResponse
targetGen uint64
appliedGen uint64
targetSetAt time.Time
wake chan struct{}
}
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
return &mapStateManager{
apply: apply,
persist: persist,
onConverged: onConverged,
wake: make(chan struct{}, 1),
}
}
// SetTarget records the latest update as the desired state and wakes the loop.
// It returns immediately; convergence happens in the background. Serial-based
// staleness of the network map is still enforced inside apply (updateNetworkMap).
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
m.mu.Lock()
// A target that has not settled yet (targetGen > appliedGen) is being superseded
// before it converged: we coalesce to the latest map and never apply this one on
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
if m.target != nil && m.targetGen > m.appliedGen {
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
}
m.target = m.mergeTarget(m.target, update)
// Bump an internal generation counter, NOT the map serial: config-only updates
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
// serial, yet must still be applied. Every SetTarget is therefore a distinct
// target regardless of payload. Map-serial staleness is enforced separately
// inside apply (updateNetworkMap).
m.targetGen++
m.targetSetAt = time.Now()
m.mu.Unlock()
select {
case m.wake <- struct{}{}:
default:
}
// Persist every update received from management — once per update (not per apply
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
// state always reflects the latest map management sent. Done after waking the loop
// so convergence can start in parallel with the disk write. The persist impl skips
// config-only updates (nil NetworkMap).
if m.persist != nil {
m.persist(update)
}
return nil
}
// mergeTarget combines the currently pending target with a freshly received update
// and returns the new desired state. It is called under m.mu from SetTarget and is
// the single seam where the replace-vs-squash decision lives.
//
// Today management always sends a FULL map (the complete desired state), so the
// update simply replaces whatever was pending — prev is ignored. When management
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
// rest of the manager (generation tracking, convergence, signaling) is unaffected
// because it already treats target as "the complete desired state, whatever it is".
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
return update
}
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
func (m *mapStateManager) run(ctx context.Context) {
// passGen is the generation of the most recent apply() call (0 = none). A pass is
// the first for its target when its generation differs from the previous one —
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
var passGen uint64
for {
m.mu.Lock()
target, tg, ag := m.target, m.targetGen, m.appliedGen
m.mu.Unlock()
// Fully converged (or nothing yet): block until a new target arrives.
if target == nil || ag == tg {
select {
case <-ctx.Done():
return
case <-m.wake:
continue
}
}
firstPass := tg != passGen
passGen = tg
more, err := m.apply(target, firstPass)
if err != nil {
if ctx.Err() != nil {
return
}
// Log and DROP this target — do not retry it. A deterministic failure
// (e.g. a malformed peer in the map) would otherwise spin every pass
// making no progress. Management is the source of truth and re-delivers
// the full map on the next sync, so dropping is safe; peers already
// applied this convergence stay (idempotent diffs) and the remainder is
// reconciled by the next target. Mirrors the legacy handleSync path,
// where the apply error was logged by the gRPC client and the update
// dropped. No onConverged: this target did not converge.
log.Errorf("apply sync pass, dropping update: %v", err)
m.settle(tg, false)
continue
}
if more {
// keep converging the current target; syncMsgMux was released by apply
// between passes so other subsystems interleave.
continue
}
// This pass converged. Mark applied and signal this one map.
m.settle(tg, true)
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
// ag<tg next iteration -> apply it; this generation was skipped (logged in
// SetTarget) and is not signaled.
}
}
// settle marks generation tg as processed so the loop goes idle instead of
// re-applying the same target. It is a no-op when a newer target arrived during the
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
// just-finished generation was already counted as skipped.
//
// When signal is true (the pass converged) it fires onConverged once for this map;
// when false (the target was dropped on error) it does not — the map did not converge.
func (m *mapStateManager) settle(tg uint64, signal bool) {
m.mu.Lock()
if m.targetGen != tg {
m.mu.Unlock()
return
}
m.appliedGen = tg
setAt := m.targetSetAt
m.mu.Unlock()
if signal && m.onConverged != nil {
m.onConverged(time.Since(setAt))
}
}

View File

@@ -0,0 +1,242 @@
package internal
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// converges over the bounded passes (apply returns more until the 3rd pass),
// fires onConverged exactly once, then blocks (no further apply) until a new target.
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
var passes int32
var firstPasses int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
if firstPass {
atomic.AddInt32(&firstPasses, 1)
}
return n < 3, nil // more on pass 1 and 2, converge on pass 3
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("manager did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
// once converged the loop blocks: no further apply calls
time.Sleep(100 * time.Millisecond)
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
}
// persist runs once per received update (not per apply pass), regardless of how many
// bounded passes that target takes to converge.
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
var passes, persists int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
return n < 3, nil // 3 passes for one target
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
}
// every update received from management is persisted — even one that is coalesced /
// skipped from apply before it ever converges.
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
release := make(chan struct{})
var persists int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
<-release // hold the first apply so the second update coalesces/skips
return false, nil
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
close(release)
// both updates persisted even though map1 is skipped from apply
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
}
// each map that is actually processed (converged before the next arrives) fires
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
converged := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
const maps = 3
for i := 0; i < maps; i++ {
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select { // wait for this map to converge before sending the next (no coalescing)
case <-converged:
case <-time.After(2 * time.Second):
t.Fatalf("map %d not signaled", i)
}
}
// no extra signals once the stream goes quiet
select {
case <-converged:
t.Fatal("unexpected extra onConverged")
case <-time.After(100 * time.Millisecond):
}
}
// a map superseded before it converges is skipped: only the latest (processed) map
// fires onConverged, not the skipped one.
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
release := make(chan struct{})
var applies, converged atomic.Int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applies.Add(1)
<-release // hold the first apply in-flight so we can queue a newer target
return false, nil
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// map1 is picked up; its apply blocks on release
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
// map2 supersedes map1 before it settled -> map1 is skipped
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
close(release) // let both applies proceed
// only the processed (latest) map signals; the skipped one does not
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
time.Sleep(150 * time.Millisecond)
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
}
// an apply error drops the target: no retry of the same target, no onConverged,
// the loop goes idle — and a fresh target is still applied afterwards.
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
applied := make(chan struct{}, 8)
var failNext atomic.Bool
failNext.Store(true)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
if failNext.Load() {
return false, errors.New("boom")
}
return false, nil // converge in one pass
}
var converged atomic.Int32
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// first target errors -> applied once, then dropped (no retry, no onConverged)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("errored target not applied")
}
select {
case <-applied:
t.Fatal("errored target must not be retried")
case <-time.After(150 * time.Millisecond):
}
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
// a new target is still processed normally and converges
failNext.Store(false)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target after error not applied")
}
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
}
// a new target after convergence triggers a fresh apply; an idle (converged)
// manager does not apply on its own.
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
applied := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("first target not applied")
}
// converged → must stay idle (no spurious apply)
select {
case <-applied:
t.Fatal("unexpected apply while idle/converged")
case <-time.After(150 * time.Millisecond):
}
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target not applied")
}
}

View File

@@ -55,14 +55,6 @@ type GrpcClient struct {
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
serverURL string
// syncStreamErr holds the last Sync stream error, or nil while the stream
// is established and healthy. GetServerKey succeeds even when the peer
// cannot sync (e.g. the server returns "settings not found"), so the
// health probe must consult this to avoid reporting a healthy management
// connection while the Sync stream keeps failing.
syncStreamMu sync.RWMutex
syncStreamErr error
}
type ExposeRequest struct {
@@ -372,8 +364,6 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
@@ -382,13 +372,11 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
log.Infof("connected to the Management Service stream")
c.notifyConnected()
c.setSyncStreamConnected()
// blocking until error
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if ctx.Err() != nil {
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
@@ -542,13 +530,6 @@ func (c *GrpcClient) IsHealthy() bool {
log.Warnf("health check returned: %s", err)
return false
}
if syncErr := c.syncStreamError(); syncErr != nil {
c.notifyDisconnected(syncErr)
log.Warnf("management transport is up but the Sync stream is unhealthy: %s", syncErr)
return false
}
c.notifyConnected()
return true
}
@@ -790,24 +771,6 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return err
}
func (c *GrpcClient) setSyncStreamConnected() {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = nil
}
func (c *GrpcClient) setSyncStreamDisconnected(err error) {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = err
}
func (c *GrpcClient) syncStreamError() error {
c.syncStreamMu.RLock()
defer c.syncStreamMu.RUnlock()
return c.syncStreamErr
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()