mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-05 06:09:56 +00:00
Compare commits
16 Commits
v0.74.2
...
netmap_pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d613189f6 | ||
|
|
7673067605 | ||
|
|
79567fe347 | ||
|
|
cf8d92fbb0 | ||
|
|
b70fc4015b | ||
|
|
4988b6726e | ||
|
|
2552830184 | ||
|
|
3b8fc688f4 | ||
|
|
d82d62e818 | ||
|
|
0bf964dad7 | ||
|
|
297dcb3e24 | ||
|
|
bc22926fe0 | ||
|
|
d3f2ef9adb | ||
|
|
5bec1e8f03 | ||
|
|
74bb5c613e | ||
|
|
29dde908ae |
@@ -220,6 +220,12 @@ type Engine struct {
|
|||||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||||
networkSerial uint64
|
networkSerial uint64
|
||||||
|
|
||||||
|
// forwardingRules holds the ingress forward rules applied for the current target.
|
||||||
|
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
|
||||||
|
// the applied rules are stashed here so the final, peer-converged pass can build the lazy-connection
|
||||||
|
// exclude list without recomputing them on every bounded peer pass.
|
||||||
|
forwardingRules []firewallManager.ForwardRule
|
||||||
|
|
||||||
networkMonitor *networkmonitor.NetworkMonitor
|
networkMonitor *networkmonitor.NetworkMonitor
|
||||||
|
|
||||||
sshServer sshServer
|
sshServer sshServer
|
||||||
@@ -774,7 +780,15 @@ func (e *Engine) blockLanAccess() {
|
|||||||
|
|
||||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
|
||||||
|
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
|
||||||
|
// batch at a time and other subsystems can interleave between passes. It is
|
||||||
|
// passed in (not read globally) so tests can exercise the multi-pass path.
|
||||||
|
const maxPeersPerSyncPass = 300
|
||||||
|
|
||||||
|
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
|
||||||
|
// when more changed peers remained than the cap, so the caller re-runs.
|
||||||
|
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||||
|
|
||||||
// first, check if peers have been modified
|
// first, check if peers have been modified
|
||||||
var modified []*mgmProto.RemotePeerConfig
|
var modified []*mgmProto.RemotePeerConfig
|
||||||
@@ -804,26 +818,32 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
more := false
|
||||||
|
if len(modified) > maxBatch {
|
||||||
|
modified = modified[:maxBatch]
|
||||||
|
more = true
|
||||||
|
}
|
||||||
|
|
||||||
// second, close all modified connections and remove them from the state map
|
// second, close all modified connections and remove them from the state map
|
||||||
for _, p := range modified {
|
for _, p := range modified {
|
||||||
err := e.removePeer(p.GetWgPubKey())
|
if err := e.removePeer(p.GetWgPubKey()); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// third, add the peer connections again
|
// third, add the peer connections again
|
||||||
for _, p := range modified {
|
for _, p := range modified {
|
||||||
err := e.addNewPeer(p)
|
if err := e.addNewPeer(p); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return more, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
// removePeers removes up to maxBatch peers per call. It returns true when more
|
||||||
|
// peers remained to remove than the cap, so the caller re-runs.
|
||||||
|
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||||
newPeers := make([]string, 0, len(peersUpdate))
|
newPeers := make([]string, 0, len(peersUpdate))
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
newPeers = append(newPeers, p.GetWgPubKey())
|
newPeers = append(newPeers, p.GetWgPubKey())
|
||||||
@@ -831,14 +851,19 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
|||||||
|
|
||||||
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||||
|
|
||||||
|
more := false
|
||||||
|
if len(toRemove) > maxBatch {
|
||||||
|
toRemove = toRemove[:maxBatch]
|
||||||
|
more = true
|
||||||
|
}
|
||||||
|
|
||||||
for _, p := range toRemove {
|
for _, p := range toRemove {
|
||||||
err := e.removePeer(p)
|
if err := e.removePeer(p); err != nil {
|
||||||
if err != nil {
|
return false, err
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
log.Infof("removed peer %s", p)
|
log.Infof("removed peer %s", p)
|
||||||
}
|
}
|
||||||
return nil
|
return more, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) removeAllPeers() error {
|
func (e *Engine) removeAllPeers() error {
|
||||||
@@ -917,19 +942,17 @@ func (e *Engine) phase(name string) func() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
|
||||||
started := time.Now()
|
// returns true if more peers remained than the per-pass cap. It is driven by the
|
||||||
defer func() {
|
// mapStateManager, which re-invokes it (releasing the lock between passes) until
|
||||||
duration := time.Since(started)
|
// the update is fully applied.
|
||||||
log.Infof("sync finished in %s", duration)
|
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
|
||||||
}()
|
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
if e.ctx.Err() != nil {
|
if e.ctx.Err() != nil {
|
||||||
return e.ctx.Err()
|
return false, e.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||||
@@ -940,7 +963,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
|
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
|
||||||
done()
|
done()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Posture checks are bound to the network map presence:
|
// Posture checks are bound to the network map presence:
|
||||||
@@ -950,28 +973,25 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
// leave the previously applied checks untouched
|
// leave the previously applied checks untouched
|
||||||
nm := update.GetNetworkMap()
|
nm := update.GetNetworkMap()
|
||||||
if nm == nil {
|
if nm == nil {
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
done = e.phase("checks")
|
done = e.phase("checks")
|
||||||
err = e.updateChecksIfNew(update.Checks)
|
err = e.updateChecksIfNew(update.Checks)
|
||||||
done()
|
done()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
done = e.phase("persist")
|
|
||||||
e.persistSyncResponse(update)
|
|
||||||
done()
|
|
||||||
|
|
||||||
// only apply new changes and ignore old ones
|
// only apply new changes and ignore old ones
|
||||||
if err := e.updateNetworkMap(nm); err != nil {
|
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
|
||||||
return err
|
if err != nil {
|
||||||
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||||
|
|
||||||
return nil
|
return more, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||||
@@ -1019,6 +1039,13 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
|||||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||||
|
// Only persist updates that carry a network map. Config-only updates (e.g. relay
|
||||||
|
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
|
||||||
|
// the last full map on disk and break restore-on-restart.
|
||||||
|
if update.GetNetworkMap() == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
e.syncRespMux.RLock()
|
e.syncRespMux.RLock()
|
||||||
defer e.syncRespMux.RUnlock()
|
defer e.syncRespMux.RUnlock()
|
||||||
|
|
||||||
@@ -1306,7 +1333,24 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
}
|
}
|
||||||
e.applyInfoFlags(info)
|
e.applyInfoFlags(info)
|
||||||
|
|
||||||
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
// The map-state manager converges the latest update in the background in
|
||||||
|
// bounded passes; the stream callback only hands it the newest target.
|
||||||
|
persist := func(u *mgmProto.SyncResponse) {
|
||||||
|
done := e.phase("persist")
|
||||||
|
e.persistSyncResponse(u)
|
||||||
|
done()
|
||||||
|
}
|
||||||
|
manager := newMapStateManager(e.applySyncPass, persist, func(d time.Duration) {
|
||||||
|
log.Infof("sync finished in %s", d)
|
||||||
|
e.clientMetrics.RecordSyncDuration(e.ctx, d)
|
||||||
|
})
|
||||||
|
e.shutdownWg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.shutdownWg.Done()
|
||||||
|
manager.run(e.ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
err := e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// happens if management is unavailable for a long time.
|
// happens if management is unavailable for a long time.
|
||||||
// We want to cancel the operation of the whole client
|
// We want to cancel the operation of the whole client
|
||||||
@@ -1357,21 +1401,107 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
|
||||||
|
// and up to maxBatch peers per phase. It returns true when more peers remained
|
||||||
|
// than the cap, so the caller re-runs until convergence.
|
||||||
|
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
|
||||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||||
if networkMap.GetPeerConfig() != nil {
|
if networkMap.GetPeerConfig() != nil {
|
||||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return false, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serial := networkMap.GetSerial()
|
serial := networkMap.GetSerial()
|
||||||
if e.networkSerial > serial {
|
if e.networkSerial > serial {
|
||||||
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
|
||||||
|
// up-front and only once per target: they are cheap, local, idempotent and must
|
||||||
|
// be in place before peers come up (fail-closed). On the bounded re-runs that only
|
||||||
|
// drain the remaining peer batches they are skipped — the applied forward rules are
|
||||||
|
// reused from e.forwardingRules for the lazy-exclude finalize.
|
||||||
|
if firstPass {
|
||||||
|
e.applyWholesale(networkMap, serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
|
doneOffline := e.phase("offline_peers")
|
||||||
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
|
doneOffline()
|
||||||
|
|
||||||
|
// Filter out own peer from the remote peers list
|
||||||
|
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||||
|
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||||
|
for _, p := range networkMap.GetRemotePeers() {
|
||||||
|
if p.GetWgPubKey() != localPubKey {
|
||||||
|
remotePeers = append(remotePeers, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No special case for cleanup: when management signals RemotePeersIsEmpty (e.g. our
|
||||||
|
// peer was deleted), remotePeers is already empty, so the bounded diff below removes
|
||||||
|
// every peer in batches — same path as a normal update, no unbounded removeAllPeers
|
||||||
|
// held under syncMsgMux in one shot.
|
||||||
|
doneRemoved := e.phase("removed_peers")
|
||||||
|
removeMore, err := e.removePeers(remotePeers, maxBatch)
|
||||||
|
doneRemoved()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
doneModified := e.phase("modified_peers")
|
||||||
|
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
|
||||||
|
doneModified()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
doneAdded := e.phase("added_peers")
|
||||||
|
addMore, err := e.addNewPeers(remotePeers, maxBatch)
|
||||||
|
doneAdded()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
|
||||||
|
needMore := removeMore || modifyMore || addMore
|
||||||
|
|
||||||
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
|
e.updatePeerSSHHostKeys(remotePeers)
|
||||||
|
|
||||||
|
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||||
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||||
|
|
||||||
|
// Set the exclude list only once peers have fully converged (this pass added
|
||||||
|
// the last batch). It needs all target peers present in the store, and
|
||||||
|
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
|
||||||
|
if !needMore {
|
||||||
|
doneLazy := e.phase("lazy_exclude")
|
||||||
|
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
|
||||||
|
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||||
|
doneLazy()
|
||||||
|
}
|
||||||
|
|
||||||
|
e.networkSerial = serial
|
||||||
|
|
||||||
|
return needMore, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
|
||||||
|
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
|
||||||
|
// ingress forward rules — that must be in place before peers come up. It runs once
|
||||||
|
// per target (first pass only); the resulting forward rules are stashed in
|
||||||
|
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
|
||||||
|
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
|
||||||
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||||
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1444,84 +1574,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update forward rules, err: %v", err)
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
}
|
}
|
||||||
done()
|
done()
|
||||||
|
e.forwardingRules = forwardingRules
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
|
||||||
|
|
||||||
done = e.phase("offline_peers")
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
|
||||||
done()
|
|
||||||
|
|
||||||
remotePeers, err := e.reconcilePeers(networkMap)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
|
||||||
done = e.phase("lazy_exclude")
|
|
||||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
|
||||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
|
||||||
done()
|
|
||||||
|
|
||||||
e.networkSerial = serial
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconcilePeers applies the remote peer list from the network map (removing,
|
|
||||||
// modifying and adding peers, then updating SSH config) and returns the remote
|
|
||||||
// peers with our own peer filtered out, for use by later sync steps.
|
|
||||||
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
|
|
||||||
// Filter out own peer from the remote peers list
|
|
||||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
|
||||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
|
||||||
for _, p := range networkMap.GetRemotePeers() {
|
|
||||||
if p.GetWgPubKey() != localPubKey {
|
|
||||||
remotePeers = append(remotePeers, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup request, most likely our peer has been deleted
|
|
||||||
if networkMap.GetRemotePeersIsEmpty() {
|
|
||||||
err := e.removeAllPeers()
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return remotePeers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
done := e.phase("removed_peers")
|
|
||||||
err := e.removePeers(remotePeers)
|
|
||||||
done()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
done = e.phase("modified_peers")
|
|
||||||
err = e.modifyPeers(remotePeers)
|
|
||||||
done()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
done = e.phase("added_peers")
|
|
||||||
err = e.addNewPeers(remotePeers)
|
|
||||||
done()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
|
||||||
|
|
||||||
e.updatePeerSSHHostKeys(remotePeers)
|
|
||||||
|
|
||||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
|
||||||
log.Warnf("failed to update SSH client config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
|
||||||
|
|
||||||
return remotePeers, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||||
@@ -1701,14 +1754,23 @@ func addrToString(addr netip.Addr) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addNewPeers adds peers that were not known before but arrived from the Management service with the update
|
// addNewPeers adds peers that were not known before but arrived from the Management service with the update
|
||||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
|
||||||
|
// when more new peers remained than the cap, so the caller re-runs.
|
||||||
|
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||||
|
added := 0
|
||||||
for _, p := range peersUpdate {
|
for _, p := range peersUpdate {
|
||||||
err := e.addNewPeer(p)
|
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
|
||||||
if err != nil {
|
continue // already present (cheap skip), does not count toward the cap
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
if added >= maxBatch {
|
||||||
|
return true, nil // at least one more new peer remains
|
||||||
|
}
|
||||||
|
if err := e.addNewPeer(p); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
added++
|
||||||
}
|
}
|
||||||
return nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addNewPeer adds a peer if the connection doesn't exist
|
// addNewPeer adds a peer if the connection doesn't exist
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
@@ -146,7 +146,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
@@ -159,7 +159,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
// time.Sleep(250 * time.Millisecond)
|
||||||
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|||||||
@@ -437,7 +437,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
|
|
||||||
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
|
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
|
||||||
t.Run(c.name, func(t *testing.T) {
|
t.Run(c.name, func(t *testing.T) {
|
||||||
err = engine.updateNetworkMap(c.networkMap)
|
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
@@ -464,6 +464,47 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// chunked apply: with a per-pass cap smaller than the number of peers, a
|
||||||
|
// single updateNetworkMap applies one batch and reports more==true; the
|
||||||
|
// caller re-runs until convergence. (engine currently holds 0 peers.)
|
||||||
|
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
|
||||||
|
nm := &mgmtProto.NetworkMap{
|
||||||
|
Serial: 6,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||||
|
}
|
||||||
|
|
||||||
|
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, more, "pass 1 should signal more")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||||
|
|
||||||
|
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, more, "pass 2 should signal more")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 2)
|
||||||
|
|
||||||
|
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, more, "pass 3 should converge")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 3)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
|
||||||
|
nm := &mgmtProto.NetworkMap{
|
||||||
|
Serial: 7,
|
||||||
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
|
||||||
|
}
|
||||||
|
|
||||||
|
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
|
||||||
|
|
||||||
|
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, more, "pass 2 should converge")
|
||||||
|
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||||
@@ -634,7 +675,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = engine.updateNetworkMap(testCase.networkMap)
|
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||||
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
|
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||||
@@ -838,7 +879,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = engine.updateNetworkMap(testCase.networkMap)
|
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||||
assert.NoError(t, err, "shouldn't return error")
|
assert.NoError(t, err, "shouldn't return error")
|
||||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||||
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
|
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
|
||||||
|
|||||||
214
client/internal/mapsync.go
Normal file
214
client/internal/mapsync.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mapStateManager is the single read/write point between the management stream
|
||||||
|
// (writes) and the convergence loop (reads/applies).
|
||||||
|
//
|
||||||
|
// The stream calls SetTarget with the latest full SyncResponse — the complete
|
||||||
|
// desired state. A single background goroutine (run) applies it to the engine in
|
||||||
|
// bounded passes via apply() until converged, releasing syncMsgMux between passes
|
||||||
|
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
|
||||||
|
// coalesces: it keeps converging toward the latest target and the intermediate one
|
||||||
|
// is SKIPPED — never applied on its own (logged, no onConverged).
|
||||||
|
//
|
||||||
|
// Convergence is a single comparison: appliedGen == targetGen. targetGen
|
||||||
|
// increments on every SetTarget (an internal generation counter, so it also covers
|
||||||
|
// config-only updates that carry no network-map serial).
|
||||||
|
//
|
||||||
|
// onConverged fires once for each — and only each — map that is actually processed
|
||||||
|
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
|
||||||
|
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
|
||||||
|
// to a real, completed alignment.
|
||||||
|
type mapStateManager struct {
|
||||||
|
// apply performs one bounded apply pass and reports whether more passes are needed.
|
||||||
|
// firstPass is true on the first pass of a given target, so the caller can run
|
||||||
|
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
|
||||||
|
// re-runs that only drain the bounded peer batches. The manager owns this signal
|
||||||
|
// because it owns the convergence boundary; the engine need not track serials for it.
|
||||||
|
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
|
||||||
|
// onConverged is called once per processed map, with the elapsed time since that
|
||||||
|
// map was received (for the sync-duration metric / "sync finished" log).
|
||||||
|
onConverged func(time.Duration)
|
||||||
|
// persist snapshots an update to disk for restore-on-restart. Called once per
|
||||||
|
// update received from management (in SetTarget), including ones later coalesced
|
||||||
|
// or skipped from apply, so the on-disk state mirrors what management last sent.
|
||||||
|
// The impl skips config-only updates (nil NetworkMap). May be nil.
|
||||||
|
persist func(*mgmProto.SyncResponse)
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
target *mgmProto.SyncResponse
|
||||||
|
targetGen uint64
|
||||||
|
appliedGen uint64
|
||||||
|
targetSetAt time.Time
|
||||||
|
|
||||||
|
wake chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
|
||||||
|
return &mapStateManager{
|
||||||
|
apply: apply,
|
||||||
|
persist: persist,
|
||||||
|
onConverged: onConverged,
|
||||||
|
wake: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTarget records the latest update as the desired state and wakes the loop.
|
||||||
|
// It returns immediately; convergence happens in the background. Serial-based
|
||||||
|
// staleness of the network map is still enforced inside apply (updateNetworkMap).
|
||||||
|
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
// A target that has not settled yet (targetGen > appliedGen) is being superseded
|
||||||
|
// before it converged: we coalesce to the latest map and never apply this one on
|
||||||
|
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
|
||||||
|
if m.target != nil && m.targetGen > m.appliedGen {
|
||||||
|
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
|
||||||
|
}
|
||||||
|
m.target = m.mergeTarget(m.target, update)
|
||||||
|
// Bump an internal generation counter, NOT the map serial: config-only updates
|
||||||
|
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
|
||||||
|
// serial, yet must still be applied. Every SetTarget is therefore a distinct
|
||||||
|
// target regardless of payload. Map-serial staleness is enforced separately
|
||||||
|
// inside apply (updateNetworkMap).
|
||||||
|
m.targetGen++
|
||||||
|
m.targetSetAt = time.Now()
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case m.wake <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Persist every update received from management — once per update (not per apply
|
||||||
|
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
|
||||||
|
// state always reflects the latest map management sent. Done after waking the loop
|
||||||
|
// so convergence can start in parallel with the disk write. The persist impl skips
|
||||||
|
// config-only updates (nil NetworkMap).
|
||||||
|
if m.persist != nil {
|
||||||
|
m.persist(update)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mergeTarget combines the currently pending target with a freshly received update
|
||||||
|
// and returns the new desired state. It is called under m.mu from SetTarget and is
|
||||||
|
// the single seam where the replace-vs-squash decision lives.
|
||||||
|
//
|
||||||
|
// Today management always sends a FULL map (the complete desired state), so the
|
||||||
|
// update simply replaces whatever was pending — prev is ignored. When management
|
||||||
|
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
|
||||||
|
// rest of the manager (generation tracking, convergence, signaling) is unaffected
|
||||||
|
// because it already treats target as "the complete desired state, whatever it is".
|
||||||
|
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
|
||||||
|
// Nothing pending to preserve (no prev, or prev already fully applied): plain replace.
|
||||||
|
if prev == nil || update == nil || m.targetGen == m.appliedGen {
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
|
||||||
|
// prev still has unapplied state (targetGen > appliedGen). In the sync protocol a
|
||||||
|
// nil component means "no change", so if `update` omits a component that prev
|
||||||
|
// carried, carry prev's forward — otherwise coalescing an update that superseded a
|
||||||
|
// not-yet-applied one would silently drop the map or config it uniquely brought.
|
||||||
|
// A present component in `update` is newer and wins. Management may send map-only
|
||||||
|
// updates (nil config) and config-only updates (nil map); both are handled here.
|
||||||
|
// A nil component in `update` means "no change", so fill it in from prev — otherwise
|
||||||
|
// coalescing an update that superseded a not-yet-applied one would drop the map or
|
||||||
|
// config it uniquely carried. A present component in `update` is newer and wins.
|
||||||
|
// We mutate `update` in place: it is a fresh per-message allocation from the sync
|
||||||
|
// stream (see receiveUpdatesEvents — not reused), and persisting this squashed target
|
||||||
|
// is correct, since it is the current full (superset) desired state.
|
||||||
|
if update.GetNetworkMap() == nil && prev.GetNetworkMap() != nil {
|
||||||
|
update.NetworkMap = prev.GetNetworkMap()
|
||||||
|
update.Checks = prev.Checks // checks travel with the map
|
||||||
|
}
|
||||||
|
if update.GetNetbirdConfig() == nil && prev.GetNetbirdConfig() != nil {
|
||||||
|
update.NetbirdConfig = prev.GetNetbirdConfig()
|
||||||
|
}
|
||||||
|
return update
|
||||||
|
}
|
||||||
|
|
||||||
|
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
|
||||||
|
func (m *mapStateManager) run(ctx context.Context) {
|
||||||
|
// passGen is the generation of the most recent apply() call (0 = none). A pass is
|
||||||
|
// the first for its target when its generation differs from the previous one —
|
||||||
|
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
|
||||||
|
var passGen uint64
|
||||||
|
for {
|
||||||
|
m.mu.Lock()
|
||||||
|
target, tg, ag := m.target, m.targetGen, m.appliedGen
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
// Fully converged (or nothing yet): block until a new target arrives.
|
||||||
|
if target == nil || ag == tg {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-m.wake:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
firstPass := tg != passGen
|
||||||
|
passGen = tg
|
||||||
|
more, err := m.apply(target, firstPass)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Log and DROP this target — do not retry it. A deterministic failure
|
||||||
|
// (e.g. a malformed peer in the map) would otherwise spin every pass
|
||||||
|
// making no progress. Management is the source of truth and re-delivers
|
||||||
|
// the full map on the next sync, so dropping is safe; peers already
|
||||||
|
// applied this convergence stay (idempotent diffs) and the remainder is
|
||||||
|
// reconciled by the next target. Mirrors the legacy handleSync path,
|
||||||
|
// where the apply error was logged by the gRPC client and the update
|
||||||
|
// dropped. No onConverged: this target did not converge.
|
||||||
|
log.Errorf("apply sync pass, dropping update: %v", err)
|
||||||
|
m.settle(tg, false)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if more {
|
||||||
|
// keep converging the current target; syncMsgMux was released by apply
|
||||||
|
// between passes so other subsystems interleave.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// This pass converged. Mark applied and signal this one map.
|
||||||
|
m.settle(tg, true)
|
||||||
|
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
|
||||||
|
// ag<tg next iteration -> apply it; this generation was skipped (logged in
|
||||||
|
// SetTarget) and is not signaled.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// settle marks generation tg as processed so the loop goes idle instead of
|
||||||
|
// re-applying the same target. It is a no-op when a newer target arrived during the
|
||||||
|
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
|
||||||
|
// just-finished generation was already counted as skipped.
|
||||||
|
//
|
||||||
|
// When signal is true (the pass converged) it fires onConverged once for this map;
|
||||||
|
// when false (the target was dropped on error) it does not — the map did not converge.
|
||||||
|
func (m *mapStateManager) settle(tg uint64, signal bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
if m.targetGen != tg {
|
||||||
|
m.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.appliedGen = tg
|
||||||
|
setAt := m.targetSetAt
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if signal && m.onConverged != nil {
|
||||||
|
m.onConverged(time.Since(setAt))
|
||||||
|
}
|
||||||
|
}
|
||||||
281
client/internal/mapsync_test.go
Normal file
281
client/internal/mapsync_test.go
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mergeTarget fills components missing from the incoming update with the pending
|
||||||
|
// (not-yet-applied) prev's, in place, so a coalesced/superseded update does not drop
|
||||||
|
// the map or config it uniquely carried.
|
||||||
|
func TestMapStateManager_MergeTargetPreservesPendingState(t *testing.T) {
|
||||||
|
m := newMapStateManager(nil, nil, nil)
|
||||||
|
|
||||||
|
// config-only update while a full map is still converging (targetGen > appliedGen):
|
||||||
|
// the pending map (+ checks) is filled into the update in place
|
||||||
|
m.targetGen, m.appliedGen = 5, 4
|
||||||
|
prev := &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||||
|
update := &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||||
|
merged := m.mergeTarget(prev, update)
|
||||||
|
require.Same(t, update, merged, "merges in place, returns the update")
|
||||||
|
require.EqualValues(t, 5, merged.GetNetworkMap().GetSerial(), "pending map preserved")
|
||||||
|
require.NotNil(t, merged.GetNetbirdConfig(), "new config kept")
|
||||||
|
|
||||||
|
// symmetric: map-only update while a config-only update is pending -> keep the config
|
||||||
|
m.targetGen, m.appliedGen = 5, 4
|
||||||
|
prev = &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||||
|
update = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 7}}
|
||||||
|
merged = m.mergeTarget(prev, update)
|
||||||
|
require.EqualValues(t, 7, merged.GetNetworkMap().GetSerial(), "new map kept")
|
||||||
|
require.NotNil(t, merged.GetNetbirdConfig(), "pending config preserved")
|
||||||
|
|
||||||
|
// prev already applied (targetGen == appliedGen): plain replace, no fill-in
|
||||||
|
m.targetGen, m.appliedGen = 5, 5
|
||||||
|
prev = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||||
|
update = &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||||
|
merged = m.mergeTarget(prev, update)
|
||||||
|
require.Same(t, update, merged)
|
||||||
|
require.Nil(t, merged.GetNetworkMap(), "no map grafted when prev already applied")
|
||||||
|
|
||||||
|
// nothing to carry (update has a map, prev has no config): plain replace
|
||||||
|
m.targetGen, m.appliedGen = 5, 4
|
||||||
|
prev = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||||
|
update = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 6}}
|
||||||
|
require.Same(t, update, m.mergeTarget(prev, update))
|
||||||
|
}
|
||||||
|
|
||||||
|
// converges over the bounded passes (apply returns more until the 3rd pass),
|
||||||
|
// fires onConverged exactly once, then blocks (no further apply) until a new target.
|
||||||
|
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
|
||||||
|
var passes int32
|
||||||
|
var firstPasses int32
|
||||||
|
converged := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||||
|
n := atomic.AddInt32(&passes, 1)
|
||||||
|
if firstPass {
|
||||||
|
atomic.AddInt32(&firstPasses, 1)
|
||||||
|
}
|
||||||
|
return n < 3, nil // more on pass 1 and 2, converge on pass 3
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-converged:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("manager did not converge")
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
|
||||||
|
|
||||||
|
// once converged the loop blocks: no further apply calls
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
|
||||||
|
}
|
||||||
|
|
||||||
|
// persist runs once per received update (not per apply pass), regardless of how many
|
||||||
|
// bounded passes that target takes to converge.
|
||||||
|
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
|
||||||
|
var passes, persists int32
|
||||||
|
converged := make(chan struct{}, 1)
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
n := atomic.AddInt32(&passes, 1)
|
||||||
|
return n < 3, nil // 3 passes for one target
|
||||||
|
}
|
||||||
|
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||||
|
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-converged:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("did not converge")
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||||
|
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
|
||||||
|
}
|
||||||
|
|
||||||
|
// every update received from management is persisted — even one that is coalesced /
|
||||||
|
// skipped from apply before it ever converges.
|
||||||
|
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
|
||||||
|
release := make(chan struct{})
|
||||||
|
var persists int32
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
<-release // hold the first apply so the second update coalesces/skips
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||||
|
m := newMapStateManager(apply, persist, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
|
||||||
|
close(release)
|
||||||
|
|
||||||
|
// both updates persisted even though map1 is skipped from apply
|
||||||
|
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// each map that is actually processed (converged before the next arrives) fires
|
||||||
|
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
|
||||||
|
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
|
||||||
|
converged := make(chan struct{}, 8)
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
return false, nil // converge in one pass
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
const maps = 3
|
||||||
|
for i := 0; i < maps; i++ {
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select { // wait for this map to converge before sending the next (no coalescing)
|
||||||
|
case <-converged:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("map %d not signaled", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// no extra signals once the stream goes quiet
|
||||||
|
select {
|
||||||
|
case <-converged:
|
||||||
|
t.Fatal("unexpected extra onConverged")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// a map superseded before it converges is skipped: only the latest (processed) map
|
||||||
|
// fires onConverged, not the skipped one.
|
||||||
|
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
|
||||||
|
release := make(chan struct{})
|
||||||
|
var applies, converged atomic.Int32
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
applies.Add(1)
|
||||||
|
<-release // hold the first apply in-flight so we can queue a newer target
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
// map1 is picked up; its apply blocks on release
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
|
||||||
|
|
||||||
|
// map2 supersedes map1 before it settled -> map1 is skipped
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
close(release) // let both applies proceed
|
||||||
|
|
||||||
|
// only the processed (latest) map signals; the skipped one does not
|
||||||
|
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
|
||||||
|
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// an apply error drops the target: no retry of the same target, no onConverged,
|
||||||
|
// the loop goes idle — and a fresh target is still applied afterwards.
|
||||||
|
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
|
||||||
|
applied := make(chan struct{}, 8)
|
||||||
|
var failNext atomic.Bool
|
||||||
|
failNext.Store(true)
|
||||||
|
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
applied <- struct{}{}
|
||||||
|
if failNext.Load() {
|
||||||
|
return false, errors.New("boom")
|
||||||
|
}
|
||||||
|
return false, nil // converge in one pass
|
||||||
|
}
|
||||||
|
var converged atomic.Int32
|
||||||
|
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
// first target errors -> applied once, then dropped (no retry, no onConverged)
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("errored target not applied")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
t.Fatal("errored target must not be retried")
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
}
|
||||||
|
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
|
||||||
|
|
||||||
|
// a new target is still processed normally and converges
|
||||||
|
failNext.Store(false)
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("new target after error not applied")
|
||||||
|
}
|
||||||
|
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// a new target after convergence triggers a fresh apply; an idle (converged)
|
||||||
|
// manager does not apply on its own.
|
||||||
|
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
|
||||||
|
applied := make(chan struct{}, 8)
|
||||||
|
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||||
|
applied <- struct{}{}
|
||||||
|
return false, nil // converge in one pass
|
||||||
|
}
|
||||||
|
m := newMapStateManager(apply, nil, nil)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
go m.run(ctx)
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("first target not applied")
|
||||||
|
}
|
||||||
|
|
||||||
|
// converged → must stay idle (no spurious apply)
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
t.Fatal("unexpected apply while idle/converged")
|
||||||
|
case <-time.After(150 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||||
|
select {
|
||||||
|
case <-applied:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("new target not applied")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -85,11 +85,7 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
|
|||||||
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
defer g.srWatcher.RemoveListener(srReconnectedChan)
|
||||||
|
|
||||||
ticker := g.initialTicker(ctx)
|
ticker := g.initialTicker(ctx)
|
||||||
defer func() {
|
defer ticker.Stop()
|
||||||
// If backoff.Ticker.send is blocked, context.Done will not close the Ticker goroutine.
|
|
||||||
// We have to explicitly call Stop, even if we use backoff.WithContext.
|
|
||||||
ticker.Stop()
|
|
||||||
}()
|
|
||||||
|
|
||||||
tickerChannel := ticker.C
|
tickerChannel := ticker.C
|
||||||
|
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
package guard
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newTestGuard(status connStatusFunc) *Guard {
|
|
||||||
srw := NewSRWatcher(nil, nil, nil, ice.Config{})
|
|
||||||
return NewGuard(log.WithField("test", "guard"), status, 50*time.Millisecond, srw)
|
|
||||||
}
|
|
||||||
|
|
||||||
// countBackoffTickerGoroutines dumps the stacks of all goroutines and counts the
// occurrences of "backoff/v4.(*Ticker).run" in the dump, i.e. the number of
// backoff ticker goroutines that have not exited yet. The dump buffer is 32MB;
// anything beyond that is truncated by runtime.Stack.
func countBackoffTickerGoroutines() int {
	const tickerFrame = "backoff/v4.(*Ticker).run"

	dump := make([]byte, 1<<25) // 32MB
	written := runtime.Stack(dump, true)
	dump = dump[:written]
	return strings.Count(string(dump), tickerFrame)
}
|
|
||||||
|
|
||||||
// TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown reproduces a observed
|
|
||||||
// leak: after a shutdown burst, ticker run/send goroutines stay parked
|
|
||||||
// forever even though every reconnect loop has exited.
|
|
||||||
func TestGuard_ReconnectTicker_NoGoroutineLeakOnShutdown(t *testing.T) {
|
|
||||||
before := countBackoffTickerGoroutines()
|
|
||||||
|
|
||||||
const peers = 6000
|
|
||||||
cancels := make([]context.CancelFunc, 0, peers)
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
|
|
||||||
// A status check slower than the tick cadence. This models the real
|
|
||||||
// isConnectedOnAllWay/callback doing work: while the loop is busy in the
|
|
||||||
// handler, the ticker fires the next tick and parks in send(), because
|
|
||||||
// send() never selects on ctx.
|
|
||||||
slowStatus := func() ConnStatus {
|
|
||||||
time.Sleep(70 * time.Millisecond)
|
|
||||||
return ConnStatusConnected
|
|
||||||
}
|
|
||||||
|
|
||||||
for range peers {
|
|
||||||
g := newTestGuard(slowStatus)
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
cancels = append(cancels, cancel)
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
g.Start(ctx, func() {})
|
|
||||||
}()
|
|
||||||
// Force the live ticker to be a newReconnectTicker.
|
|
||||||
g.SetRelayedConnDisconnected()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Let the replacement tickers get past their 800ms initial interval, so
|
|
||||||
// many are parked in send() waiting on the (slow) consumer when we tear
|
|
||||||
// everything down.
|
|
||||||
time.Sleep(1500 * time.Millisecond)
|
|
||||||
|
|
||||||
// Shutdown burst: cancel every peer at once, like engine teardown.
|
|
||||||
for _, c := range cancels {
|
|
||||||
c()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Every reconnect loop must return
|
|
||||||
waitCh := make(chan struct{})
|
|
||||||
go func() { wg.Wait(); close(waitCh) }()
|
|
||||||
select {
|
|
||||||
case <-waitCh:
|
|
||||||
case <-time.After(30 * time.Second):
|
|
||||||
t.Fatal("not all reconnect loops returned after ctx cancel")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Give any correctly-stopped ticker goroutines time to unwind.
|
|
||||||
for range 50 {
|
|
||||||
runtime.Gosched()
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
|
|
||||||
leaked := countBackoffTickerGoroutines() - before
|
|
||||||
t.Logf("backoff Ticker.run goroutines still parked after teardown of %d peers: %d", peers, leaked)
|
|
||||||
if leaked > 0 {
|
|
||||||
t.Errorf("LEAK: %d backoff ticker goroutines parked after all reconnect loops exited "+
|
|
||||||
"(defer ticker.Stop() stops the initial ticker, not the live replacement)", leaked)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newExitNodeTestManager() *DefaultManager {
|
|
||||||
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
|
|
||||||
}
|
|
||||||
|
|
||||||
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
|
|
||||||
return &route.Route{
|
|
||||||
NetID: route.NetID(netID),
|
|
||||||
Network: netip.MustParsePrefix("0.0.0.0/0"),
|
|
||||||
Peer: peer,
|
|
||||||
SkipAutoApply: skipAutoApply,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPickPreferredExitNode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
info exitNodeInfo
|
|
||||||
want route.NetID
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "persisted user selection wins over management",
|
|
||||||
info: exitNodeInfo{
|
|
||||||
allIDs: []route.NetID{"a", "b", "c"},
|
|
||||||
userSelected: []route.NetID{"b"},
|
|
||||||
selectedByManagement: []route.NetID{"a"},
|
|
||||||
},
|
|
||||||
want: "b",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple user-selected self-heal to deterministic min",
|
|
||||||
info: exitNodeInfo{
|
|
||||||
allIDs: []route.NetID{"a", "b", "c"},
|
|
||||||
userSelected: []route.NetID{"c", "a"},
|
|
||||||
},
|
|
||||||
want: "a",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "explicit opt-out keeps none",
|
|
||||||
info: exitNodeInfo{
|
|
||||||
allIDs: []route.NetID{"a", "b"},
|
|
||||||
userDeselected: []route.NetID{"a", "b"},
|
|
||||||
},
|
|
||||||
want: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "fresh defaults to management auto-apply pick",
|
|
||||||
info: exitNodeInfo{
|
|
||||||
allIDs: []route.NetID{"a", "b", "c"},
|
|
||||||
selectedByManagement: []route.NetID{"b"},
|
|
||||||
},
|
|
||||||
want: "b",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no user pick and no management auto-apply selects none",
|
|
||||||
info: exitNodeInfo{
|
|
||||||
allIDs: []route.NetID{"c", "a", "b"},
|
|
||||||
},
|
|
||||||
want: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user-deselect does not block a management auto-apply sibling",
|
|
||||||
info: exitNodeInfo{
|
|
||||||
allIDs: []route.NetID{"a", "b"},
|
|
||||||
userDeselected: []route.NetID{"a"},
|
|
||||||
selectedByManagement: []route.NetID{"b"},
|
|
||||||
},
|
|
||||||
want: "b",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnforceSingleExitNode(t *testing.T) {
|
|
||||||
m := newExitNodeTestManager()
|
|
||||||
all := []route.NetID{"a", "b", "c"}
|
|
||||||
|
|
||||||
m.enforceSingleExitNode("b", all)
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
|
|
||||||
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
|
|
||||||
|
|
||||||
// Switching the preferred node moves the single selection.
|
|
||||||
m.enforceSingleExitNode("c", all)
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
|
|
||||||
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
|
|
||||||
|
|
||||||
// Empty preferred turns every exit node off.
|
|
||||||
m.enforceSingleExitNode("", all)
|
|
||||||
for _, id := range all {
|
|
||||||
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
|
|
||||||
m := newExitNodeTestManager()
|
|
||||||
m.routeSelector.DeselectAllRoutes()
|
|
||||||
|
|
||||||
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
|
|
||||||
|
|
||||||
assert.True(t, m.routeSelector.IsDeselectAll(), "global deselect-all must stay in effect")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
|
|
||||||
m := newExitNodeTestManager()
|
|
||||||
routes := route.HAMap{
|
|
||||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
|
||||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
|
||||||
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
|
|
||||||
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateRouteSelectorFromManagement(routes)
|
|
||||||
|
|
||||||
// Exactly one exit node (the deterministic first) is selected.
|
|
||||||
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
|
|
||||||
// Non-exit routes are left at their default-on state.
|
|
||||||
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
|
|
||||||
m := newExitNodeTestManager()
|
|
||||||
routes := route.HAMap{
|
|
||||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
|
||||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
|
||||||
}
|
|
||||||
all := []route.NetID{"exitA", "exitB"}
|
|
||||||
|
|
||||||
// Simulate the state the runtime select path leaves behind: exactly one
|
|
||||||
// exit node explicitly selected, its sibling deselected.
|
|
||||||
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
|
|
||||||
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
|
|
||||||
|
|
||||||
m.updateRouteSelectorFromManagement(routes)
|
|
||||||
|
|
||||||
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
|
|
||||||
m := newExitNodeTestManager()
|
|
||||||
routes := route.HAMap{
|
|
||||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
|
|
||||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
|
|
||||||
}
|
|
||||||
all := []route.NetID{"exitA", "exitB"}
|
|
||||||
|
|
||||||
// User deselected exit nodes and selected none.
|
|
||||||
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
|
|
||||||
|
|
||||||
m.updateRouteSelectorFromManagement(routes)
|
|
||||||
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
|
|
||||||
m := newExitNodeTestManager()
|
|
||||||
// SkipAutoApply=true: management offers the exit nodes but doesn't request
|
|
||||||
// auto-activation, so none should be selected until the user picks one.
|
|
||||||
routes := route.HAMap{
|
|
||||||
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
|
|
||||||
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
|
|
||||||
}
|
|
||||||
|
|
||||||
m.updateRouteSelectorFromManagement(routes)
|
|
||||||
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
|
|
||||||
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
|
|
||||||
}
|
|
||||||
@@ -701,13 +701,7 @@ func resolveURLsToIPs(urls []string) []net.IP {
|
|||||||
return ips
|
return ips
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateRouteSelectorFromManagement reconciles exit-node selection on every
|
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
|
||||||
// network map: it keeps at most one exit node selected — the user's persisted
|
|
||||||
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
|
|
||||||
// else none. We never auto-activate an exit node the map doesn't request; it
|
|
||||||
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
|
|
||||||
// RouteSelector stores routes with default-on semantics, so without this every
|
|
||||||
// available exit node would report selected at once.
|
|
||||||
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
|
||||||
m.mirrorV6ExitPairSelections(clientRoutes)
|
m.mirrorV6ExitPairSelections(clientRoutes)
|
||||||
|
|
||||||
@@ -718,14 +712,13 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
info := m.collectExitNodeInfo(clientRoutes)
|
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
|
||||||
if len(info.allIDs) == 0 {
|
if len(exitNodeInfo.allIDs) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
preferred := pickPreferredExitNode(info)
|
m.updateExitNodeSelections(exitNodeInfo)
|
||||||
m.enforceSingleExitNode(preferred, info.allIDs)
|
m.logExitNodeUpdate(exitNodeInfo)
|
||||||
m.logExitNodeUpdate(info, preferred)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
|
||||||
@@ -753,10 +746,6 @@ type exitNodeInfo struct {
|
|||||||
userDeselected []route.NetID
|
userDeselected []route.NetID
|
||||||
}
|
}
|
||||||
|
|
||||||
// collectExitNodeInfo categorises the available exit nodes by their persisted
|
|
||||||
// selection state. It keys on the base (v4) NetID and skips the synthesized
|
|
||||||
// "-v6" partner, which inherits its base's selection through the RouteSelector
|
|
||||||
// — counting it separately would double-count the pair.
|
|
||||||
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
|
||||||
var info exitNodeInfo
|
var info exitNodeInfo
|
||||||
|
|
||||||
@@ -766,9 +755,6 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
|
|||||||
}
|
}
|
||||||
|
|
||||||
netID := haID.NetID()
|
netID := haID.NetID()
|
||||||
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
info.allIDs = append(info.allIDs, netID)
|
info.allIDs = append(info.allIDs, netID)
|
||||||
|
|
||||||
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
if m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||||
@@ -805,52 +791,45 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
|
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
|
||||||
// - a persisted user selection wins (deterministic if several survive from
|
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
|
||||||
// legacy state, so the set self-heals down to one);
|
m.deselectExitNodes(routesToDeselect)
|
||||||
// - otherwise activate only what management marks for auto-apply
|
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
|
||||||
// (SkipAutoApply=false); the lexicographically first if it marks several.
|
|
||||||
//
|
|
||||||
// Returns "" when neither holds — we never force an arbitrary exit node on. A
|
|
||||||
// route the map doesn't auto-apply stays off until the user selects it.
|
|
||||||
// info.userDeselected is informational only: an explicit deselect simply keeps
|
|
||||||
// that route out of both lists above, so it can't be picked.
|
|
||||||
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
|
|
||||||
if len(info.userSelected) > 0 {
|
|
||||||
return minNetID(info.userSelected)
|
|
||||||
}
|
|
||||||
if len(info.selectedByManagement) > 0 {
|
|
||||||
return minNetID(info.selectedByManagement)
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// enforceSingleExitNode makes preferred the only selected exit node: every other
|
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
|
||||||
// available exit node is deselected and preferred (if any) is selected, without
|
var routesToDeselect []route.NetID
|
||||||
// disturbing non-exit route selections. The whole reconciliation runs under a
|
for _, netID := range allIDs {
|
||||||
// single RouteSelector lock (SetExclusiveExitNode) so a concurrent deselect-all
|
if !m.routeSelector.HasUserSelectionForRoute(netID) {
|
||||||
// cannot interleave and get undone; a global deselect-all is left untouched so
|
routesToDeselect = append(routesToDeselect, netID)
|
||||||
// the user's "all off" stays in effect.
|
|
||||||
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
|
|
||||||
m.routeSelector.SetExclusiveExitNode(preferred, allIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
|
|
||||||
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
|
|
||||||
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
|
|
||||||
}
|
|
||||||
|
|
||||||
// minNetID returns the lexicographically smallest NetID, for a deterministic
|
|
||||||
// default pick that stays stable across restarts.
|
|
||||||
func minNetID(ids []route.NetID) route.NetID {
|
|
||||||
if len(ids) == 0 {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
best := ids[0]
|
|
||||||
for _, id := range ids[1:] {
|
|
||||||
if id < best {
|
|
||||||
best = id
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return best
|
return routesToDeselect
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
|
||||||
|
if len(routesToDeselect) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to deselect exit nodes: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
|
||||||
|
if len(selectedByManagement) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to select exit nodes: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
|
||||||
|
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
|
||||||
|
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,38 +115,7 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
|||||||
clear(rs.selectedRoutes)
|
clear(rs.selectedRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetExclusiveExitNode atomically makes preferred the only selected exit node
|
// IsDeselectAll reports whether the user has explicitly deselected all routes.
|
||||||
// among exitIDs: every other ID in exitIDs is deselected and preferred (when
|
|
||||||
// non-empty) is selected, all under a single lock. Holding the lock across the
|
|
||||||
// whole reconciliation prevents a concurrent DeselectAllRoutes from interleaving
|
|
||||||
// between the deselect and select steps and being silently undone. A global
|
|
||||||
// deselect-all is left untouched so the user's "all off" stays in effect;
|
|
||||||
// non-exit routes are never referenced, so their selection is preserved.
|
|
||||||
func (rs *RouteSelector) SetExclusiveExitNode(preferred route.NetID, exitIDs []route.NetID) {
|
|
||||||
rs.mu.Lock()
|
|
||||||
defer rs.mu.Unlock()
|
|
||||||
|
|
||||||
if rs.deselectAll {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, id := range exitIDs {
|
|
||||||
if id == preferred {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
rs.deselectedRoutes[id] = struct{}{}
|
|
||||||
delete(rs.selectedRoutes, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
if preferred != "" {
|
|
||||||
delete(rs.deselectedRoutes, preferred)
|
|
||||||
rs.selectedRoutes[preferred] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsDeselectAll reports whether the global "deselect all" flag is set, i.e. the
|
|
||||||
// user explicitly disabled every route. Callers enforcing per-route invariants
|
|
||||||
// (e.g. single exit node) should leave the selection untouched when it is.
|
|
||||||
func (rs *RouteSelector) IsDeselectAll() bool {
|
func (rs *RouteSelector) IsDeselectAll() bool {
|
||||||
rs.mu.RLock()
|
rs.mu.RLock()
|
||||||
defer rs.mu.RUnlock()
|
defer rs.mu.RUnlock()
|
||||||
|
|||||||
Reference in New Issue
Block a user