mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-02 12:49:54 +00:00
Compare commits
18 Commits
dependabot
...
netmap_pro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7673067605 | ||
|
|
79567fe347 | ||
|
|
cf8d92fbb0 | ||
|
|
b70fc4015b | ||
|
|
4ef65294e9 | ||
|
|
5b5f11740a | ||
|
|
3de889d529 | ||
|
|
4988b6726e | ||
|
|
2552830184 | ||
|
|
3b8fc688f4 | ||
|
|
d82d62e818 | ||
|
|
0bf964dad7 | ||
|
|
297dcb3e24 | ||
|
|
bc22926fe0 | ||
|
|
d3f2ef9adb | ||
|
|
5bec1e8f03 | ||
|
|
74bb5c613e | ||
|
|
29dde908ae |
@@ -136,6 +136,11 @@ func (p *ProxyBind) CloseConn() error {
|
||||
return p.close()
|
||||
}
|
||||
|
||||
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
|
||||
func (p *ProxyBind) InjectPacket(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *ProxyBind) close() error {
|
||||
if p.remoteConn == nil {
|
||||
return nil
|
||||
|
||||
@@ -219,6 +219,17 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.pausedCond.L.Unlock()
|
||||
}
|
||||
|
||||
// InjectPacket writes b to the remote peer over the underlying transport.
|
||||
func (p *ProxyWrapper) InjectPacket(b []byte) error {
|
||||
if p.remoteConn == nil {
|
||||
return errors.New("proxy not started")
|
||||
}
|
||||
if _, err := p.remoteConn.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseConn close the remoteConn and automatically remove the conn instance from the map
|
||||
func (p *ProxyWrapper) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
|
||||
@@ -18,4 +18,9 @@ type Proxy interface {
|
||||
RedirectAs(endpoint *net.UDPAddr)
|
||||
CloseConn() error
|
||||
SetDisconnectListener(disconnected func())
|
||||
|
||||
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
|
||||
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
|
||||
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
|
||||
InjectPacket(b []byte) error
|
||||
}
|
||||
|
||||
@@ -147,6 +147,17 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
|
||||
p.sendPkg = p.srcFakerConn.SendPkg
|
||||
}
|
||||
|
||||
// InjectPacket writes b to the remote peer over the underlying transport.
|
||||
func (p *WGUDPProxy) InjectPacket(b []byte) error {
|
||||
if p.remoteConn == nil {
|
||||
return errors.New("proxy not started")
|
||||
}
|
||||
if _, err := p.remoteConn.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CloseConn close the localConn
|
||||
func (p *WGUDPProxy) CloseConn() error {
|
||||
if p.cancel == nil {
|
||||
|
||||
@@ -82,6 +82,12 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
disableAutoUpdate = "disabled"
|
||||
|
||||
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
|
||||
// check gathering. The gathering runs uncancellable system calls (process scan,
|
||||
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
|
||||
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
|
||||
systemInfoTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -210,6 +216,12 @@ type Engine struct {
|
||||
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
|
||||
networkSerial uint64
|
||||
|
||||
// forwardingRules holds the ingress forward rules applied for the current target.
|
||||
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
|
||||
// it is stashed here so the final, peer-converged pass can build the lazy-connection
|
||||
// exclude list without recomputing them on every bounded peer pass.
|
||||
forwardingRules []firewallManager.ForwardRule
|
||||
|
||||
networkMonitor *networkmonitor.NetworkMonitor
|
||||
|
||||
sshServer sshServer
|
||||
@@ -762,7 +774,15 @@ func (e *Engine) blockLanAccess() {
|
||||
|
||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
|
||||
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
|
||||
// batch at a time and other subsystems can interleave between passes. It is
|
||||
// passed in (not read globally) so tests can exercise the multi-pass path.
|
||||
const maxPeersPerSyncPass = 300
|
||||
|
||||
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
|
||||
// when more changed peers remained than the cap, so the caller re-runs.
|
||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||
|
||||
// first, check if peers have been modified
|
||||
var modified []*mgmProto.RemotePeerConfig
|
||||
@@ -792,26 +812,32 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
}
|
||||
}
|
||||
|
||||
more := false
|
||||
if len(modified) > maxBatch {
|
||||
modified = modified[:maxBatch]
|
||||
more = true
|
||||
}
|
||||
|
||||
// second, close all modified connections and remove them from the state map
|
||||
for _, p := range modified {
|
||||
err := e.removePeer(p.GetWgPubKey())
|
||||
if err != nil {
|
||||
return err
|
||||
if err := e.removePeer(p.GetWgPubKey()); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
// third, add the peer connections again
|
||||
for _, p := range modified {
|
||||
err := e.addNewPeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := e.addNewPeer(p); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return more, nil
|
||||
}
|
||||
|
||||
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
|
||||
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
|
||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
// removePeers removes up to maxBatch peers per call. It returns true when more
|
||||
// peers remained to remove than the cap, so the caller re-runs.
|
||||
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||
newPeers := make([]string, 0, len(peersUpdate))
|
||||
for _, p := range peersUpdate {
|
||||
newPeers = append(newPeers, p.GetWgPubKey())
|
||||
@@ -819,14 +845,19 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
|
||||
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
|
||||
|
||||
more := false
|
||||
if len(toRemove) > maxBatch {
|
||||
toRemove = toRemove[:maxBatch]
|
||||
more = true
|
||||
}
|
||||
|
||||
for _, p := range toRemove {
|
||||
err := e.removePeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := e.removePeer(p); err != nil {
|
||||
return false, err
|
||||
}
|
||||
log.Infof("removed peer %s", p)
|
||||
}
|
||||
return nil
|
||||
return more, nil
|
||||
}
|
||||
|
||||
func (e *Engine) removeAllPeers() error {
|
||||
@@ -905,19 +936,17 @@ func (e *Engine) phase(name string) func() {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
started := time.Now()
|
||||
defer func() {
|
||||
duration := time.Since(started)
|
||||
log.Infof("sync finished in %s", duration)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||
}()
|
||||
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
|
||||
// returns true if more peers remained than the per-pass cap. It is driven by the
|
||||
// mapStateManager, which re-invokes it (releasing the lock between passes) until
|
||||
// the update is fully applied.
|
||||
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
return e.ctx.Err()
|
||||
return false, e.ctx.Err()
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
@@ -928,7 +957,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
|
||||
done()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Posture checks are bound to the network map presence:
|
||||
@@ -938,28 +967,25 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
// leave the previously applied checks untouched
|
||||
nm := update.GetNetworkMap()
|
||||
if nm == nil {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
done = e.phase("checks")
|
||||
err = e.updateChecksIfNew(update.Checks)
|
||||
done()
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
|
||||
done = e.phase("persist")
|
||||
e.persistSyncResponse(update)
|
||||
done()
|
||||
|
||||
// only apply new changes and ignore old ones
|
||||
if err := e.updateNetworkMap(nm); err != nil {
|
||||
return err
|
||||
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
|
||||
|
||||
return nil
|
||||
return more, nil
|
||||
}
|
||||
|
||||
// updateNetbirdConfig applies the management-provided NetBird configuration:
|
||||
@@ -1005,6 +1031,13 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
|
||||
// engine close) mid-call and have this write resurrect a file that was just removed.
|
||||
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
|
||||
// Only persist updates that carry a network map. Config-only updates (e.g. relay
|
||||
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
|
||||
// the last full map on disk and break restore-on-restart.
|
||||
if update.GetNetworkMap() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
e.syncRespMux.RLock()
|
||||
defer e.syncRespMux.RUnlock()
|
||||
|
||||
@@ -1084,11 +1117,22 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...)
|
||||
if !ok {
|
||||
// Gathering timed out; skip the meta sync this cycle rather than blocking the
|
||||
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
|
||||
return nil
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
return fmt.Errorf("could not sync meta: error %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
|
||||
func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
@@ -1107,12 +1151,6 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
|
||||
@@ -1272,31 +1310,32 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...)
|
||||
if !ok {
|
||||
// Gathering timed out; connect the stream with base info so management
|
||||
// connectivity still comes up rather than blocking here.
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
e.applyInfoFlags(info)
|
||||
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
// The map-state manager converges the latest update in the background in
|
||||
// bounded passes; the stream callback only hands it the newest target.
|
||||
persist := func(u *mgmProto.SyncResponse) {
|
||||
done := e.phase("persist")
|
||||
e.persistSyncResponse(u)
|
||||
done()
|
||||
}
|
||||
manager := newMapStateManager(e.applySyncPass, persist, func(d time.Duration) {
|
||||
log.Infof("sync finished in %s", d)
|
||||
e.clientMetrics.RecordSyncDuration(e.ctx, d)
|
||||
})
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
manager.run(e.ctx)
|
||||
}()
|
||||
|
||||
err := e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
@@ -1347,21 +1386,107 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
|
||||
// and up to maxBatch peers per phase. It returns true when more peers remained
|
||||
// than the cap, so the caller re-runs until convergence.
|
||||
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
|
||||
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
|
||||
if networkMap.GetPeerConfig() != nil {
|
||||
err := e.updateConfig(networkMap.GetPeerConfig())
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
serial := networkMap.GetSerial()
|
||||
if e.networkSerial > serial {
|
||||
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
|
||||
// up-front and only once per target: they are cheap, local, idempotent and must
|
||||
// be in place before peers come up (fail-closed). On the bounded re-runs that only
|
||||
// drain the remaining peer batches they are skipped — the applied forward rules are
|
||||
// reused from e.forwardingRules for the lazy-exclude finalize.
|
||||
if firstPass {
|
||||
e.applyWholesale(networkMap, serial)
|
||||
}
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
doneOffline := e.phase("offline_peers")
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
doneOffline()
|
||||
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
for _, p := range networkMap.GetRemotePeers() {
|
||||
if p.GetWgPubKey() != localPubKey {
|
||||
remotePeers = append(remotePeers, p)
|
||||
}
|
||||
}
|
||||
|
||||
// No special case for cleanup: when management signals RemotePeersIsEmpty (e.g. our
|
||||
// peer was deleted), remotePeers is already empty, so the bounded diff below removes
|
||||
// every peer in batches — same path as a normal update, no unbounded removeAllPeers
|
||||
// held under syncMsgMux in one shot.
|
||||
doneRemoved := e.phase("removed_peers")
|
||||
removeMore, err := e.removePeers(remotePeers, maxBatch)
|
||||
doneRemoved()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
doneModified := e.phase("modified_peers")
|
||||
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
|
||||
doneModified()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
doneAdded := e.phase("added_peers")
|
||||
addMore, err := e.addNewPeers(remotePeers, maxBatch)
|
||||
doneAdded()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
|
||||
needMore := removeMore || modifyMore || addMore
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
// Set the exclude list only once peers have fully converged (this pass added
|
||||
// the last batch). It needs all target peers present in the store, and
|
||||
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
|
||||
if !needMore {
|
||||
doneLazy := e.phase("lazy_exclude")
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
doneLazy()
|
||||
}
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
return needMore, nil
|
||||
}
|
||||
|
||||
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
|
||||
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
|
||||
// ingress forward rules — that must be in place before peers come up. It runs once
|
||||
// per target (first pass only); the resulting forward rules are stashed in
|
||||
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
|
||||
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
|
||||
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
|
||||
log.Errorf("failed to update lazy connection feature flag: %v", err)
|
||||
}
|
||||
@@ -1434,84 +1559,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
log.Errorf("failed to update forward rules, err: %v", err)
|
||||
}
|
||||
done()
|
||||
|
||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||
|
||||
done = e.phase("offline_peers")
|
||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||
done()
|
||||
|
||||
remotePeers, err := e.reconcilePeers(networkMap)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
done = e.phase("lazy_exclude")
|
||||
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
|
||||
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
|
||||
done()
|
||||
|
||||
e.networkSerial = serial
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reconcilePeers applies the remote peer list from the network map (removing,
|
||||
// modifying and adding peers, then updating SSH config) and returns the remote
|
||||
// peers with our own peer filtered out, for use by later sync steps.
|
||||
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
|
||||
// Filter out own peer from the remote peers list
|
||||
localPubKey := e.config.WgPrivateKey.PublicKey().String()
|
||||
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
|
||||
for _, p := range networkMap.GetRemotePeers() {
|
||||
if p.GetWgPubKey() != localPubKey {
|
||||
remotePeers = append(remotePeers, p)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup request, most likely our peer has been deleted
|
||||
if networkMap.GetRemotePeersIsEmpty() {
|
||||
err := e.removeAllPeers()
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return remotePeers, nil
|
||||
}
|
||||
|
||||
done := e.phase("removed_peers")
|
||||
err := e.removePeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done = e.phase("modified_peers")
|
||||
err = e.modifyPeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
done = e.phase("added_peers")
|
||||
err = e.addNewPeers(remotePeers)
|
||||
done()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
e.updatePeerSSHHostKeys(remotePeers)
|
||||
|
||||
if err := e.updateSSHClientConfig(remotePeers); err != nil {
|
||||
log.Warnf("failed to update SSH client config: %v", err)
|
||||
}
|
||||
|
||||
e.updateSSHServerAuth(networkMap.GetSshAuth())
|
||||
|
||||
return remotePeers, nil
|
||||
e.forwardingRules = forwardingRules
|
||||
}
|
||||
|
||||
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
|
||||
@@ -1691,14 +1739,23 @@ func addrToString(addr netip.Addr) string {
|
||||
}
|
||||
|
||||
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
|
||||
// when more new peers remained than the cap, so the caller re-runs.
|
||||
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
|
||||
added := 0
|
||||
for _, p := range peersUpdate {
|
||||
err := e.addNewPeer(p)
|
||||
if err != nil {
|
||||
return err
|
||||
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
|
||||
continue // already present (cheap skip), does not count toward the cap
|
||||
}
|
||||
if added >= maxBatch {
|
||||
return true, nil // at least one more new peer remains
|
||||
}
|
||||
if err := e.addNewPeer(p); err != nil {
|
||||
return false, err
|
||||
}
|
||||
added++
|
||||
}
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addNewPeer add peer if connection doesn't exist
|
||||
|
||||
@@ -124,7 +124,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
@@ -146,7 +146,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
@@ -159,7 +159,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
// time.Sleep(250 * time.Millisecond)
|
||||
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
|
||||
RemotePeersIsEmpty: false,
|
||||
}
|
||||
|
||||
err = engine.updateNetworkMap(networkMap)
|
||||
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Nil(t, engine.sshServer)
|
||||
|
||||
@@ -178,6 +178,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIface) MTU() uint16 {
|
||||
return 1280
|
||||
}
|
||||
|
||||
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
|
||||
return nil
|
||||
}
|
||||
@@ -433,7 +437,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
|
||||
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
err = engine.updateNetworkMap(c.networkMap)
|
||||
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
@@ -460,6 +464,47 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// chunked apply: with a per-pass cap smaller than the number of peers, a
|
||||
// single updateNetworkMap applies one batch and reports more==true; the
|
||||
// caller re-runs until convergence. (engine currently holds 0 peers.)
|
||||
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
|
||||
nm := &mgmtProto.NetworkMap{
|
||||
Serial: 6,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
|
||||
}
|
||||
|
||||
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, more, "pass 1 should signal more")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||
|
||||
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||
require.NoError(t, err)
|
||||
require.True(t, more, "pass 2 should signal more")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 2)
|
||||
|
||||
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, more, "pass 3 should converge")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 3)
|
||||
})
|
||||
|
||||
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
|
||||
nm := &mgmtProto.NetworkMap{
|
||||
Serial: 7,
|
||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
|
||||
}
|
||||
|
||||
more, err := engine.updateNetworkMap(nm, 1, true)
|
||||
require.NoError(t, err)
|
||||
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
|
||||
|
||||
more, err = engine.updateNetworkMap(nm, 1, false)
|
||||
require.NoError(t, err)
|
||||
require.False(t, more, "pass 2 should converge")
|
||||
require.Len(t, engine.peerStore.PeersPubKey(), 1)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
@@ -630,7 +675,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
|
||||
@@ -834,7 +879,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
}
|
||||
}()
|
||||
|
||||
err = engine.updateNetworkMap(testCase.networkMap)
|
||||
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
|
||||
assert.NoError(t, err, "shouldn't return error")
|
||||
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
|
||||
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")
|
||||
|
||||
@@ -44,4 +44,5 @@ type wgIfaceBase interface {
|
||||
FullStats() (*configurer.Stats, error)
|
||||
LastActivities() map[string]monotime.Time
|
||||
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
@@ -124,6 +124,11 @@ func (d *BindListener) ReadPackets() {
|
||||
d.done.Done()
|
||||
}
|
||||
|
||||
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
|
||||
func (d *BindListener) CapturedPacket() []byte {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *BindListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")
|
||||
|
||||
@@ -45,10 +45,6 @@ type MockWGIfaceBind struct {
|
||||
endpointMgr *mockEndpointManager
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
}
|
||||
@@ -68,6 +64,10 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
|
||||
return m.endpointMgr
|
||||
}
|
||||
|
||||
func (m *MockWGIfaceBind) MTU() uint16 {
|
||||
return 1280
|
||||
}
|
||||
|
||||
func TestBindListener_Creation(t *testing.T) {
|
||||
mockEndpointMgr := newMockEndpointManager()
|
||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||
@@ -207,8 +207,9 @@ func TestManager_BindMode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
|
||||
case ev := <-mgr.OnActivityChan:
|
||||
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
|
||||
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notification")
|
||||
}
|
||||
@@ -266,8 +267,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||
receivedPeers := make(map[peerid.ConnID]bool)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
receivedPeers[peerConnID] = true
|
||||
case ev := <-mgr.OnActivityChan:
|
||||
receivedPeers[ev.PeerConnID] = true
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for activity notifications")
|
||||
}
|
||||
|
||||
@@ -3,11 +3,13 @@ package activity
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/bufsize"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
@@ -20,6 +22,8 @@ type UDPListener struct {
|
||||
done sync.Mutex
|
||||
|
||||
isClosed atomic.Bool
|
||||
|
||||
capturedPacket []byte
|
||||
}
|
||||
|
||||
// NewUDPListener creates a listener that detects activity via UDP socket reads.
|
||||
@@ -46,9 +50,13 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
|
||||
}
|
||||
|
||||
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
|
||||
// The first packet that triggers activity is captured so it can be reinjected through the real
|
||||
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
|
||||
// dropped and WG would only retry after REKEY_TIMEOUT.
|
||||
func (d *UDPListener) ReadPackets() {
|
||||
for {
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
|
||||
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
|
||||
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
if d.isClosed.Load() {
|
||||
d.peerCfg.Log.Infof("exit from activity listener")
|
||||
@@ -62,20 +70,24 @@ func (d *UDPListener) ReadPackets() {
|
||||
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
|
||||
continue
|
||||
}
|
||||
d.peerCfg.Log.Infof("activity detected")
|
||||
d.capturedPacket = slices.Clone(buf[:n])
|
||||
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
|
||||
break
|
||||
}
|
||||
|
||||
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
|
||||
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
|
||||
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
|
||||
}
|
||||
|
||||
// Ignore close error as it may return "use of closed network connection" if already closed.
|
||||
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
|
||||
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
|
||||
// triggered activation.
|
||||
_ = d.conn.Close()
|
||||
d.done.Unlock()
|
||||
}
|
||||
|
||||
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
|
||||
// Safe to call after ReadPackets returns.
|
||||
func (d *UDPListener) CapturedPacket() []byte {
|
||||
return d.capturedPacket
|
||||
}
|
||||
|
||||
// Close stops the listener and cleans up resources.
|
||||
func (d *UDPListener) Close() {
|
||||
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())
|
||||
|
||||
@@ -19,17 +19,25 @@ import (
|
||||
type listener interface {
|
||||
ReadPackets()
|
||||
Close()
|
||||
CapturedPacket() []byte
|
||||
}
|
||||
|
||||
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
|
||||
// captured for reinjection through the real transport.
|
||||
type Event struct {
|
||||
PeerConnID peerid.ConnID
|
||||
FirstPacket []byte
|
||||
}
|
||||
|
||||
type WgInterface interface {
|
||||
RemovePeer(peerKey string) error
|
||||
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
OnActivityChan chan peerid.ConnID
|
||||
OnActivityChan chan Event
|
||||
|
||||
wgIface WgInterface
|
||||
|
||||
@@ -41,7 +49,7 @@ type Manager struct {
|
||||
|
||||
func NewManager(wgIface WgInterface) *Manager {
|
||||
m := &Manager{
|
||||
OnActivityChan: make(chan peerid.ConnID, 1),
|
||||
OnActivityChan: make(chan Event, 1),
|
||||
wgIface: wgIface,
|
||||
peers: make(map[peerid.ConnID]listener),
|
||||
done: make(chan struct{}),
|
||||
@@ -116,12 +124,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
|
||||
delete(m.peers, peerConnID)
|
||||
m.mu.Unlock()
|
||||
|
||||
m.notify(peerConnID)
|
||||
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
|
||||
}
|
||||
|
||||
func (m *Manager) notify(peerConnID peerid.ConnID) {
|
||||
func (m *Manager) notify(ev Event) {
|
||||
select {
|
||||
case <-m.done:
|
||||
case m.OnActivityChan <- peerConnID:
|
||||
case m.OnActivityChan <- ev:
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package activity
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
@@ -25,10 +26,6 @@ func (m *MocPeer) ConnID() peerid.ConnID {
|
||||
type MocWGIface struct {
|
||||
}
|
||||
|
||||
func (m MocWGIface) RemovePeer(string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
|
||||
return nil
|
||||
}
|
||||
@@ -44,6 +41,10 @@ func (m MocWGIface) Address() wgaddr.Address {
|
||||
}
|
||||
}
|
||||
|
||||
func (m MocWGIface) MTU() uint16 {
|
||||
return 1280
|
||||
}
|
||||
|
||||
// GetPeerListener is a test helper to access listeners
|
||||
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
|
||||
m.mu.Lock()
|
||||
@@ -86,11 +87,15 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
|
||||
}
|
||||
|
||||
select {
|
||||
case peerConnID := <-mgr.OnActivityChan:
|
||||
if peerConnID != peerCfg1.PeerConnID {
|
||||
t.Fatalf("unexpected peerConnID: %v", peerConnID)
|
||||
case ev := <-mgr.OnActivityChan:
|
||||
if ev.PeerConnID != peerCfg1.PeerConnID {
|
||||
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
|
||||
}
|
||||
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
|
||||
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
|
||||
}
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("timed out waiting for activity")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case peerConnID := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(peerConnID)
|
||||
case ev := <-m.activityManager.OnActivityChan:
|
||||
m.onPeerActivity(ev)
|
||||
case peerIDs := <-m.inactivityManager.InactivePeersChan():
|
||||
m.onPeerInactivityTimedOut(peerIDs)
|
||||
}
|
||||
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
func (m *Manager) onPeerActivity(ev activity.Event) {
|
||||
m.managedPeersMu.Lock()
|
||||
defer m.managedPeersMu.Unlock()
|
||||
|
||||
mp, ok := m.managedPeersByConnID[peerConnID]
|
||||
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
|
||||
if !ok {
|
||||
log.Errorf("peer not found by conn id: %v", peerConnID)
|
||||
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
|
||||
|
||||
m.activateHAGroupPeers(mp.peerCfg)
|
||||
|
||||
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
|
||||
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
|
||||
}
|
||||
|
||||
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {
|
||||
|
||||
@@ -17,4 +17,5 @@ type WGIface interface {
|
||||
IsUserspaceBind() bool
|
||||
Address() wgaddr.Address
|
||||
LastActivities() map[string]monotime.Time
|
||||
MTU() uint16
|
||||
}
|
||||
|
||||
214
client/internal/mapsync.go
Normal file
214
client/internal/mapsync.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// mapStateManager is the single read/write point between the management stream
|
||||
// (writes) and the convergence loop (reads/applies).
|
||||
//
|
||||
// The stream calls SetTarget with the latest full SyncResponse — the complete
|
||||
// desired state. A single background goroutine (run) applies it to the engine in
|
||||
// bounded passes via apply() until converged, releasing syncMsgMux between passes
|
||||
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
|
||||
// coalesces: it keeps converging toward the latest target and the intermediate one
|
||||
// is SKIPPED — never applied on its own (logged, no onConverged).
|
||||
//
|
||||
// Convergence is a single comparison: appliedGen == targetGen. targetGen
|
||||
// increments on every SetTarget (an internal generation counter, so it also covers
|
||||
// config-only updates that carry no network-map serial).
|
||||
//
|
||||
// onConverged fires once for each — and only each — map that is actually processed
|
||||
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
|
||||
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
|
||||
// to a real, completed alignment.
|
||||
type mapStateManager struct {
|
||||
// apply performs one bounded apply pass and reports whether more passes are needed.
|
||||
// firstPass is true on the first pass of a given target, so the caller can run
|
||||
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
|
||||
// re-runs that only drain the bounded peer batches. The manager owns this signal
|
||||
// because it owns the convergence boundary; the engine need not track serials for it.
|
||||
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
|
||||
// onConverged is called once per processed map, with the elapsed time since that
|
||||
// map was received (for the sync-duration metric / "sync finished" log).
|
||||
onConverged func(time.Duration)
|
||||
// persist snapshots an update to disk for restore-on-restart. Called once per
|
||||
// update received from management (in SetTarget), including ones later coalesced
|
||||
// or skipped from apply, so the on-disk state mirrors what management last sent.
|
||||
// The impl skips config-only updates (nil NetworkMap). May be nil.
|
||||
persist func(*mgmProto.SyncResponse)
|
||||
|
||||
mu sync.Mutex
|
||||
target *mgmProto.SyncResponse
|
||||
targetGen uint64
|
||||
appliedGen uint64
|
||||
targetSetAt time.Time
|
||||
|
||||
wake chan struct{}
|
||||
}
|
||||
|
||||
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
|
||||
return &mapStateManager{
|
||||
apply: apply,
|
||||
persist: persist,
|
||||
onConverged: onConverged,
|
||||
wake: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// SetTarget records the latest update as the desired state and wakes the loop.
|
||||
// It returns immediately; convergence happens in the background. Serial-based
|
||||
// staleness of the network map is still enforced inside apply (updateNetworkMap).
|
||||
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
|
||||
m.mu.Lock()
|
||||
// A target that has not settled yet (targetGen > appliedGen) is being superseded
|
||||
// before it converged: we coalesce to the latest map and never apply this one on
|
||||
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
|
||||
if m.target != nil && m.targetGen > m.appliedGen {
|
||||
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
|
||||
}
|
||||
m.target = m.mergeTarget(m.target, update)
|
||||
// Bump an internal generation counter, NOT the map serial: config-only updates
|
||||
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
|
||||
// serial, yet must still be applied. Every SetTarget is therefore a distinct
|
||||
// target regardless of payload. Map-serial staleness is enforced separately
|
||||
// inside apply (updateNetworkMap).
|
||||
m.targetGen++
|
||||
m.targetSetAt = time.Now()
|
||||
m.mu.Unlock()
|
||||
|
||||
select {
|
||||
case m.wake <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
|
||||
// Persist every update received from management — once per update (not per apply
|
||||
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
|
||||
// state always reflects the latest map management sent. Done after waking the loop
|
||||
// so convergence can start in parallel with the disk write. The persist impl skips
|
||||
// config-only updates (nil NetworkMap).
|
||||
if m.persist != nil {
|
||||
m.persist(update)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mergeTarget combines the currently pending target with a freshly received update
|
||||
// and returns the new desired state. It is called under m.mu from SetTarget and is
|
||||
// the single seam where the replace-vs-squash decision lives.
|
||||
//
|
||||
// Today management always sends a FULL map (the complete desired state), so the
|
||||
// update simply replaces whatever was pending — prev is ignored. When management
|
||||
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
|
||||
// rest of the manager (generation tracking, convergence, signaling) is unaffected
|
||||
// because it already treats target as "the complete desired state, whatever it is".
|
||||
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
|
||||
// Nothing pending to preserve (no prev, or prev already fully applied): plain replace.
|
||||
if prev == nil || update == nil || m.targetGen == m.appliedGen {
|
||||
return update
|
||||
}
|
||||
|
||||
// prev still has unapplied state (targetGen > appliedGen). In the sync protocol a
|
||||
// nil component means "no change", so if `update` omits a component that prev
|
||||
// carried, carry prev's forward — otherwise coalescing an update that superseded a
|
||||
// not-yet-applied one would silently drop the map or config it uniquely brought.
|
||||
// A present component in `update` is newer and wins. Management may send map-only
|
||||
// updates (nil config) and config-only updates (nil map); both are handled here.
|
||||
// A nil component in `update` means "no change", so fill it in from prev — otherwise
|
||||
// coalescing an update that superseded a not-yet-applied one would drop the map or
|
||||
// config it uniquely carried. A present component in `update` is newer and wins.
|
||||
// We mutate `update` in place: it is a fresh per-message allocation from the sync
|
||||
// stream (see receiveUpdatesEvents — not reused), and persisting this squashed target
|
||||
// is correct, since it is the current full (superset) desired state.
|
||||
if update.GetNetworkMap() == nil && prev.GetNetworkMap() != nil {
|
||||
update.NetworkMap = prev.GetNetworkMap()
|
||||
update.Checks = prev.Checks // checks travel with the map
|
||||
}
|
||||
if update.GetNetbirdConfig() == nil && prev.GetNetbirdConfig() != nil {
|
||||
update.NetbirdConfig = prev.GetNetbirdConfig()
|
||||
}
|
||||
return update
|
||||
}
|
||||
|
||||
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
|
||||
func (m *mapStateManager) run(ctx context.Context) {
|
||||
// passGen is the generation of the most recent apply() call (0 = none). A pass is
|
||||
// the first for its target when its generation differs from the previous one —
|
||||
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
|
||||
var passGen uint64
|
||||
for {
|
||||
m.mu.Lock()
|
||||
target, tg, ag := m.target, m.targetGen, m.appliedGen
|
||||
m.mu.Unlock()
|
||||
|
||||
// Fully converged (or nothing yet): block until a new target arrives.
|
||||
if target == nil || ag == tg {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-m.wake:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
firstPass := tg != passGen
|
||||
passGen = tg
|
||||
more, err := m.apply(target, firstPass)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
// Log and DROP this target — do not retry it. A deterministic failure
|
||||
// (e.g. a malformed peer in the map) would otherwise spin every pass
|
||||
// making no progress. Management is the source of truth and re-delivers
|
||||
// the full map on the next sync, so dropping is safe; peers already
|
||||
// applied this convergence stay (idempotent diffs) and the remainder is
|
||||
// reconciled by the next target. Mirrors the legacy handleSync path,
|
||||
// where the apply error was logged by the gRPC client and the update
|
||||
// dropped. No onConverged: this target did not converge.
|
||||
log.Errorf("apply sync pass, dropping update: %v", err)
|
||||
m.settle(tg, false)
|
||||
continue
|
||||
}
|
||||
|
||||
if more {
|
||||
// keep converging the current target; syncMsgMux was released by apply
|
||||
// between passes so other subsystems interleave.
|
||||
continue
|
||||
}
|
||||
|
||||
// This pass converged. Mark applied and signal this one map.
|
||||
m.settle(tg, true)
|
||||
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
|
||||
// ag<tg next iteration -> apply it; this generation was skipped (logged in
|
||||
// SetTarget) and is not signaled.
|
||||
}
|
||||
}
|
||||
|
||||
// settle marks generation tg as processed so the loop goes idle instead of
|
||||
// re-applying the same target. It is a no-op when a newer target arrived during the
|
||||
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
|
||||
// just-finished generation was already counted as skipped.
|
||||
//
|
||||
// When signal is true (the pass converged) it fires onConverged once for this map;
|
||||
// when false (the target was dropped on error) it does not — the map did not converge.
|
||||
func (m *mapStateManager) settle(tg uint64, signal bool) {
|
||||
m.mu.Lock()
|
||||
if m.targetGen != tg {
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.appliedGen = tg
|
||||
setAt := m.targetSetAt
|
||||
m.mu.Unlock()
|
||||
|
||||
if signal && m.onConverged != nil {
|
||||
m.onConverged(time.Since(setAt))
|
||||
}
|
||||
}
|
||||
281
client/internal/mapsync_test.go
Normal file
281
client/internal/mapsync_test.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// mergeTarget fills components missing from the incoming update with the pending
|
||||
// (not-yet-applied) prev's, in place, so a coalesced/superseded update does not drop
|
||||
// the map or config it uniquely carried.
|
||||
func TestMapStateManager_MergeTargetPreservesPendingState(t *testing.T) {
|
||||
m := newMapStateManager(nil, nil, nil)
|
||||
|
||||
// config-only update while a full map is still converging (targetGen > appliedGen):
|
||||
// the pending map (+ checks) is filled into the update in place
|
||||
m.targetGen, m.appliedGen = 5, 4
|
||||
prev := &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||
update := &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||
merged := m.mergeTarget(prev, update)
|
||||
require.Same(t, update, merged, "merges in place, returns the update")
|
||||
require.EqualValues(t, 5, merged.GetNetworkMap().GetSerial(), "pending map preserved")
|
||||
require.NotNil(t, merged.GetNetbirdConfig(), "new config kept")
|
||||
|
||||
// symmetric: map-only update while a config-only update is pending -> keep the config
|
||||
m.targetGen, m.appliedGen = 5, 4
|
||||
prev = &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||
update = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 7}}
|
||||
merged = m.mergeTarget(prev, update)
|
||||
require.EqualValues(t, 7, merged.GetNetworkMap().GetSerial(), "new map kept")
|
||||
require.NotNil(t, merged.GetNetbirdConfig(), "pending config preserved")
|
||||
|
||||
// prev already applied (targetGen == appliedGen): plain replace, no fill-in
|
||||
m.targetGen, m.appliedGen = 5, 5
|
||||
prev = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||
update = &mgmProto.SyncResponse{NetbirdConfig: &mgmProto.NetbirdConfig{}}
|
||||
merged = m.mergeTarget(prev, update)
|
||||
require.Same(t, update, merged)
|
||||
require.Nil(t, merged.GetNetworkMap(), "no map grafted when prev already applied")
|
||||
|
||||
// nothing to carry (update has a map, prev has no config): plain replace
|
||||
m.targetGen, m.appliedGen = 5, 4
|
||||
prev = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 5}}
|
||||
update = &mgmProto.SyncResponse{NetworkMap: &mgmProto.NetworkMap{Serial: 6}}
|
||||
require.Same(t, update, m.mergeTarget(prev, update))
|
||||
}
|
||||
|
||||
// converges over the bounded passes (apply returns more until the 3rd pass),
|
||||
// fires onConverged exactly once, then blocks (no further apply) until a new target.
|
||||
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
|
||||
var passes int32
|
||||
var firstPasses int32
|
||||
converged := make(chan struct{}, 1)
|
||||
|
||||
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
|
||||
n := atomic.AddInt32(&passes, 1)
|
||||
if firstPass {
|
||||
atomic.AddInt32(&firstPasses, 1)
|
||||
}
|
||||
return n < 3, nil // more on pass 1 and 2, converge on pass 3
|
||||
}
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
|
||||
select {
|
||||
case <-converged:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("manager did not converge")
|
||||
}
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
|
||||
|
||||
// once converged the loop blocks: no further apply calls
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
|
||||
}
|
||||
|
||||
// persist runs once per received update (not per apply pass), regardless of how many
|
||||
// bounded passes that target takes to converge.
|
||||
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
|
||||
var passes, persists int32
|
||||
converged := make(chan struct{}, 1)
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
n := atomic.AddInt32(&passes, 1)
|
||||
return n < 3, nil // 3 passes for one target
|
||||
}
|
||||
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-converged:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("did not converge")
|
||||
}
|
||||
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
|
||||
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
|
||||
}
|
||||
|
||||
// every update received from management is persisted — even one that is coalesced /
|
||||
// skipped from apply before it ever converges.
|
||||
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
|
||||
release := make(chan struct{})
|
||||
var persists int32
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
<-release // hold the first apply so the second update coalesces/skips
|
||||
return false, nil
|
||||
}
|
||||
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
|
||||
m := newMapStateManager(apply, persist, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
|
||||
close(release)
|
||||
|
||||
// both updates persisted even though map1 is skipped from apply
|
||||
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// each map that is actually processed (converged before the next arrives) fires
|
||||
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
|
||||
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
|
||||
converged := make(chan struct{}, 8)
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
return false, nil // converge in one pass
|
||||
}
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
const maps = 3
|
||||
for i := 0; i < maps; i++ {
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select { // wait for this map to converge before sending the next (no coalescing)
|
||||
case <-converged:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("map %d not signaled", i)
|
||||
}
|
||||
}
|
||||
|
||||
// no extra signals once the stream goes quiet
|
||||
select {
|
||||
case <-converged:
|
||||
t.Fatal("unexpected extra onConverged")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
|
||||
// a map superseded before it converges is skipped: only the latest (processed) map
|
||||
// fires onConverged, not the skipped one.
|
||||
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
|
||||
release := make(chan struct{})
|
||||
var applies, converged atomic.Int32
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
applies.Add(1)
|
||||
<-release // hold the first apply in-flight so we can queue a newer target
|
||||
return false, nil
|
||||
}
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
// map1 is picked up; its apply blocks on release
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
|
||||
|
||||
// map2 supersedes map1 before it settled -> map1 is skipped
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
close(release) // let both applies proceed
|
||||
|
||||
// only the processed (latest) map signals; the skipped one does not
|
||||
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
|
||||
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
|
||||
}
|
||||
|
||||
// an apply error drops the target: no retry of the same target, no onConverged,
|
||||
// the loop goes idle — and a fresh target is still applied afterwards.
|
||||
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
|
||||
applied := make(chan struct{}, 8)
|
||||
var failNext atomic.Bool
|
||||
failNext.Store(true)
|
||||
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
applied <- struct{}{}
|
||||
if failNext.Load() {
|
||||
return false, errors.New("boom")
|
||||
}
|
||||
return false, nil // converge in one pass
|
||||
}
|
||||
var converged atomic.Int32
|
||||
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
// first target errors -> applied once, then dropped (no retry, no onConverged)
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("errored target not applied")
|
||||
}
|
||||
select {
|
||||
case <-applied:
|
||||
t.Fatal("errored target must not be retried")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
|
||||
|
||||
// a new target is still processed normally and converges
|
||||
failNext.Store(false)
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("new target after error not applied")
|
||||
}
|
||||
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
// a new target after convergence triggers a fresh apply; an idle (converged)
|
||||
// manager does not apply on its own.
|
||||
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
|
||||
applied := make(chan struct{}, 8)
|
||||
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
|
||||
applied <- struct{}{}
|
||||
return false, nil // converge in one pass
|
||||
}
|
||||
m := newMapStateManager(apply, nil, nil)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go m.run(ctx)
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("first target not applied")
|
||||
}
|
||||
|
||||
// converged → must stay idle (no spurious apply)
|
||||
select {
|
||||
case <-applied:
|
||||
t.Fatal("unexpected apply while idle/converged")
|
||||
case <-time.After(150 * time.Millisecond):
|
||||
}
|
||||
|
||||
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
|
||||
select {
|
||||
case <-applied:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("new target not applied")
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -136,6 +137,39 @@ type Conn struct {
|
||||
// Connection stage timestamps for metrics
|
||||
metricsRecorder MetricsRecorder
|
||||
metricsStages *MetricsStages
|
||||
|
||||
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
|
||||
// transport is up.
|
||||
pendingFirstPacket []byte
|
||||
}
|
||||
|
||||
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
|
||||
// directly through the ICE conn. The packet is cleared only after a successful write, so a failed
|
||||
// or transport-less attempt leaves it available for a later reinjection. Caller must hold conn.mu.
|
||||
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
|
||||
pkt := conn.pendingFirstPacket
|
||||
if len(pkt) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case proxy != nil:
|
||||
if err := proxy.InjectPacket(pkt); err != nil {
|
||||
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
|
||||
return
|
||||
}
|
||||
case directConn != nil:
|
||||
if _, err := directConn.Write(pkt); err != nil {
|
||||
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
conn.Log.Debugf("no transport available to reinject captured first packet")
|
||||
return
|
||||
}
|
||||
|
||||
conn.pendingFirstPacket = nil
|
||||
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
|
||||
}
|
||||
|
||||
// NewConn creates a new not opened Conn to the remote peer.
|
||||
@@ -172,6 +206,16 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
return conn.open(engineCtx, nil)
|
||||
}
|
||||
|
||||
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
|
||||
// the real transport is established. The packet is retained only on a successful open.
|
||||
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
|
||||
return conn.open(engineCtx, firstPacket)
|
||||
}
|
||||
|
||||
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
@@ -227,6 +271,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
defer conn.wg.Done()
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
}()
|
||||
if len(firstPacket) > 0 {
|
||||
conn.pendingFirstPacket = slices.Clone(firstPacket)
|
||||
}
|
||||
conn.opened = true
|
||||
return nil
|
||||
}
|
||||
@@ -423,6 +470,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
||||
conn.wgProxyRelay.RedirectAs(ep)
|
||||
}
|
||||
|
||||
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
|
||||
|
||||
conn.currentConnPriority = priority
|
||||
conn.statusICE.SetConnected()
|
||||
conn.updateIceState(iceConnInfo, updateTime)
|
||||
@@ -546,6 +595,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
|
||||
wgConfigWorkaround()
|
||||
|
||||
conn.injectPendingFirstPacket(wgProxy, nil)
|
||||
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
|
||||
@@ -88,11 +88,24 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
// this can be blocked because of the connect open limiter semaphore
|
||||
if err := p.Open(ctx); err != nil {
|
||||
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
|
||||
// reinjected once the real transport is established.
|
||||
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
|
||||
s.peerConnsMu.RLock()
|
||||
defer s.peerConnsMu.RUnlock()
|
||||
|
||||
p, ok := s.peerConns[pubKey]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := p.OpenWithFirstPacket(ctx, firstPacket); err != nil {
|
||||
p.Log.Errorf("failed to open peer connection: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) PeerConnIdle(pubKey string) {
|
||||
|
||||
@@ -2,9 +2,11 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -174,7 +176,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
}
|
||||
|
||||
files, err := checkFileAndProcess(processCheckPaths)
|
||||
files, err := checkFileAndProcess(ctx, processCheckPaths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -187,3 +189,43 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
|
||||
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
|
||||
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
|
||||
// does not return within timeout the caller gets (nil, false) and should proceed with
|
||||
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
|
||||
//
|
||||
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
|
||||
// returns, so it does not leak beyond the duration of that call.
|
||||
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, bool) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
infoCh := make(chan *Info, 1)
|
||||
go func() {
|
||||
info, err := GetInfoWithChecks(ctx, checks, excludeIPs...)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = GetInfo(ctx)
|
||||
info.removeAddresses(excludeIPs...)
|
||||
}
|
||||
infoCh <- info
|
||||
}()
|
||||
|
||||
select {
|
||||
case info := <-infoCh:
|
||||
return info, true
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
log.Warnf("gathering system info with checks timed out after %s", timeout)
|
||||
} else {
|
||||
// Parent context canceled (e.g. shutdown), not a timeout.
|
||||
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
||||
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
||||
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
||||
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
|
||||
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
|
||||
if err != nil {
|
||||
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
||||
swVersion = []byte(release)
|
||||
|
||||
@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -35,6 +36,20 @@ func Test_CustomHostname(t *testing.T) {
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
|
||||
assert.True(t, ok, "expected gathering to complete within the timeout")
|
||||
assert.NotNil(t, info)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
|
||||
// A 1ns budget expires before the (real) system-info gathering can finish, so the
|
||||
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
|
||||
assert.False(t, ok, "expected timeout to be reported")
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_NetAddresses(t *testing.T) {
|
||||
addr, err := networkAddresses()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,24 +3,30 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
// getRunningProcesses returns a list of running process paths.
|
||||
func getRunningProcesses() ([]string, error) {
|
||||
processIDs, err := process.Pids()
|
||||
// getRunningProcesses returns a list of running process paths. The context bounds the work:
|
||||
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
|
||||
// can, so a stuck enumeration cannot run unbounded.
|
||||
func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
processIDs, err := process.PidsWithContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processMap := make(map[string]bool)
|
||||
for _, pID := range processIDs {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &process.Process{Pid: pID}
|
||||
|
||||
path, _ := p.Exe()
|
||||
path, _ := p.ExeWithContext(ctx)
|
||||
if path != "" {
|
||||
processMap[path] = false
|
||||
}
|
||||
@@ -35,18 +41,21 @@ func getRunningProcesses() ([]string, error) {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
|
||||
files := make([]File, len(paths))
|
||||
if len(paths) == 0 {
|
||||
return files, nil
|
||||
}
|
||||
|
||||
runningProcesses, err := getRunningProcesses()
|
||||
runningProcesses, err := getRunningProcesses(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file := File{Path: path}
|
||||
|
||||
_, err := os.Stat(path)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
@@ -9,7 +10,7 @@ import (
|
||||
func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
b.Run("getRunningProcesses new", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ps, err := getRunningProcesses()
|
||||
ps, err := getRunningProcesses(context.Background())
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -29,12 +30,38 @@ func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
}
|
||||
}
|
||||
})
|
||||
s, _ := getRunningProcesses()
|
||||
s, _ := getRunningProcesses(context.Background())
|
||||
b.Logf("getRunningProcesses returned %d processes", len(s))
|
||||
s, _ = getRunningProcessesOld()
|
||||
b.Logf("getRunningProcessesOld returned %d processes", len(s))
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
// With a canceled context and non-empty paths the gathering must bail with an error
|
||||
// instead of running the (potentially blocking) process scan / stat loop.
|
||||
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
|
||||
t.Fatal("expected error on canceled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
|
||||
// No check paths means no work to do: it must return immediately with no error,
|
||||
// even on a canceled context (nothing to scan or stat).
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
files, err := checkFileAndProcess(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty paths: %v", err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Fatalf("expected no files, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func getRunningProcessesOld() ([]string, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
|
||||
@@ -9,6 +9,8 @@ set -o pipefail
|
||||
|
||||
SED_STRIP_PADDING='s/=//g'
|
||||
|
||||
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
@@ -139,6 +141,43 @@ read_yes_no() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Gate the install on explicit acceptance of the NetBird On-Premise EULA.
|
||||
require_eula_acceptance() {
|
||||
cat > /dev/stderr <<EOF
|
||||
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird On-Premise End User License Agreement
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird's on-premise software is commercial software, licensed and not
|
||||
sold. Your installation, deployment and use are governed by the NetBird
|
||||
On-Premise End User License Agreement (the "EULA"). Please read the EULA
|
||||
in full before continuing:
|
||||
|
||||
${NETBIRD_EULA_URL}
|
||||
|
||||
By typing "accept" and continuing the installation, you confirm that you
|
||||
have read and agree to the EULA, that you are authorized to accept it on
|
||||
behalf of your organization (the "Customer"), and that the Software is
|
||||
used for business purposes only.
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
EOF
|
||||
|
||||
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
|
||||
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
|
||||
return 0
|
||||
fi
|
||||
|
||||
local ans=""
|
||||
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
|
||||
read -r ans < /dev/tty
|
||||
if [[ "$ans" != "accept" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "EULA not accepted. Aborting installation." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
wait_postgres() {
|
||||
set +e
|
||||
echo -n "Waiting for postgres to become ready"
|
||||
@@ -174,6 +213,9 @@ init_environment() {
|
||||
exit 1
|
||||
fi
|
||||
|
||||
require_eula_acceptance
|
||||
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
|
||||
echo "NetBird Enterprise bootstrap"
|
||||
echo ""
|
||||
echo "Traffic flow:"
|
||||
@@ -260,6 +302,11 @@ render_env() {
|
||||
# Generated by getting-started-enterprise.sh
|
||||
# Holds all configuration and secrets for the stack. Mode 600.
|
||||
|
||||
# NetBird On-Premise EULA acceptance
|
||||
NETBIRD_EULA_ACCEPTED=yes
|
||||
NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}
|
||||
NETBIRD_EULA_URL=${NETBIRD_EULA_URL}
|
||||
|
||||
# Features (set by the script; don't edit without re-running)
|
||||
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ set -o pipefail
|
||||
OVERRIDE_FILE="docker-compose.override.yml"
|
||||
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
|
||||
|
||||
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
|
||||
|
||||
check_docker_compose() {
|
||||
if command -v docker-compose &> /dev/null; then
|
||||
echo "docker-compose"
|
||||
@@ -115,6 +117,43 @@ read_yes_no() {
|
||||
esac
|
||||
}
|
||||
|
||||
# Gate the migration on explicit acceptance of the NetBird On-Premise EULA.
|
||||
require_eula_acceptance() {
|
||||
cat > /dev/stderr <<EOF
|
||||
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird On-Premise End User License Agreement
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
NetBird's on-premise software is commercial software, licensed and not
|
||||
sold. Your installation, deployment and use are governed by the NetBird
|
||||
On-Premise End User License Agreement (the "EULA"). Please read the EULA
|
||||
in full before continuing:
|
||||
|
||||
${NETBIRD_EULA_URL}
|
||||
|
||||
By typing "accept" and continuing the installation, you confirm that you
|
||||
have read and agree to the EULA, that you are authorized to accept it on
|
||||
behalf of your organization (the "Customer"), and that the Software is
|
||||
used for business purposes only.
|
||||
──────────────────────────────────────────────────────────────────────
|
||||
EOF
|
||||
|
||||
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
|
||||
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
|
||||
return 0
|
||||
fi
|
||||
|
||||
local ans=""
|
||||
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
|
||||
read -r ans < /dev/tty
|
||||
if [[ "$ans" != "accept" ]]; then
|
||||
echo "" > /dev/stderr
|
||||
echo "EULA not accepted. Aborting migration." > /dev/stderr
|
||||
exit 1
|
||||
fi
|
||||
echo "" > /dev/stderr
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection — read the operator's existing compose to find service names and
|
||||
# paths we need to override. Bail loudly if shape isn't recognised.
|
||||
@@ -436,6 +475,9 @@ init_migration() {
|
||||
echo " Network: $COMPOSE_NETWORK"
|
||||
echo ""
|
||||
|
||||
require_eula_acceptance
|
||||
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
|
||||
|
||||
local proceed
|
||||
proceed=$(read_yes_no "Proceed with migration?" "y")
|
||||
if [[ "$proceed" != "yes" ]]; then
|
||||
@@ -529,6 +571,10 @@ apply_changes() {
|
||||
{
|
||||
echo ""
|
||||
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
|
||||
echo "# NetBird On-Premise EULA accepted at install time"
|
||||
echo "NETBIRD_EULA_ACCEPTED=yes"
|
||||
echo "NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}"
|
||||
echo "NETBIRD_EULA_URL=${NETBIRD_EULA_URL}"
|
||||
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
|
||||
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
|
||||
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
|
||||
|
||||
Reference in New Issue
Block a user