Compare commits

...

12 Commits

Author SHA1 Message Date
Viktor Liu
159e5736cd Make lazy connections opt-out with NB_LAZY_CONN and MDM overrides 2026-06-30 17:01:24 +02:00
Viktor Liu
4ef65294e9 [client] Reinject captured first packet on lazy connection activation (#6572) 2026-06-30 11:22:25 +02:00
Bethuel Mmbaga
5b5f11740a [misc] Require on-premise EULA acceptance in enterprise scripts (#6596) 2026-06-30 11:34:23 +03:00
Riccardo Manfrin
3de889d529 [client] bound system info / posture-check gathering with a timeout to prevent sync-loop freeze (#6512)
* Wraps syestem info / posture checks into a goroutine with timeout

e.checks = checks is set before doing the SyncMeta,
so if it fails next time isCheckEquals compares true and bypasses
the update. This is to avoid another repeating the 15 seconds hang.
The checks will be synced on reconnect or posture checks changes
push from mgmt.

* Propagate context to OS calls that can leverage its cancellation / timeout

* Distinguish timeout from cancellation in logs

* Dont log twice

* Block on timeout failure and reapply the exclude_ips

* Refactor for complexity
2026-06-30 08:18:51 +02:00
Zoltan Papp
04c3d19032 [client] Skip firewall ruleset rebuild when config is unchanged (#6508)
* [client] Skip firewall ruleset rebuild when config is unchanged

ApplyFiltering rebuilt every peer and route ACL and flushed the firewall
on every sync, with no guard for an unchanged configuration. Management
re-sends the same network map far more often than it actually changes
(account-wide updates, peer meta churn), so on busy accounts this is the
dominant client-side cost of redundant syncs — especially with a large
route set and a userspace firewall.

Hash the inputs ApplyFiltering consumes (peer rules, route rules, the
empty flag and the dns-route feature flag) and skip the rebuild + flush
when the hash matches the last successfully applied update. Mirrors the
guard the DNS server already uses (previousConfigHash). The hash is only
recorded after apply and flush both succeed, so a failed update is not
skipped on the next (possibly identical) sync and gets a chance to
reconcile the firewall state.

* [client] Include config hash in ACL skip debug log

* [client] Include RoutesFirewallRulesIsEmpty in firewall config hash

* [client] Add benchmarks for firewall config hash computation
2026-06-29 19:51:50 +02:00
Zoltan Papp
3f1fb3b52d [ingest] raise duration validation limit to 24 hours (#6598)
Peer connection timing fields (signaling_to_connection_seconds) can
legitimately exceed 5 minutes during long reconnections; the previous
300 s cap caused valid data points to be rejected.
2026-06-29 19:51:25 +02:00
Viktor Liu
b434cda062 [client] Refresh signal receive liveness when worker handoff drains (#6594) 2026-06-29 12:16:47 +02:00
Zoltan Papp
0b594c639a [client] report management unhealthy while Sync stream is failing (#6575)
* fix(mgm): report management unhealthy while Sync stream is failing

The health probe (IsHealthy) only checked the gRPC transport and a
GetServerKey call. GetServerKey succeeds even when the peer cannot sync
(e.g. the server returns "settings not found"), so the probe kept marking
management Connected while the Sync stream failed in a tight retry loop —
pinning the status to "Connected" forever despite no sync ever succeeding.

Track the last Sync stream error and have IsHealthy consult it, so a
healthy transport is no longer enough to report the connection healthy.

* fix(mgm): record disconnected state when sync stream setup fails

The connectToSyncStream failure path in handleSyncStream returned early
without updating syncStreamErr, so the client could still report healthy
even when stream setup failed. Mirror the receiveUpdatesEvents error path
by calling notifyDisconnected and setSyncStreamDisconnected.
2026-06-29 11:28:58 +02:00
Zoltan Papp
deff8af59f [client] Wait for signal receive watchdog to stop before reconnect (#6574)
* [client] Wait for signal receive watchdog to stop before reconnect

The per-stream watchReceiveStream goroutine was started fire-and-forget
and never joined. On reconnect a lingering watchdog could still flip
shared client state (receiveStalled, the disconnect notifier) on the
freshly established stream, since cancelStream only cancels its own
stream context.

Track the watchdog with a WaitGroup and wait for it to exit (after
cancelling its stream) before the operation returns, so each reconnect
starts with no stale watchdog.

* [client] Bind signal receive probe to the stream context

The watchdog probe reused the generic Send, which derives its per-attempt
timeouts from the long-lived client context, so cancelStream could not
interrupt an in-flight probe. After joining the watchdog on reconnect,
watchdogWg.Wait() could then block for the full send-attempt chain.

Split Send into a context-aware send and pass the stream context down
through sendReceiveProbe, so cancelStream aborts any in-flight probe and
the watchdog exits promptly.
2026-06-29 11:24:25 +02:00
Riccardo Manfrin
5711f0e38c [client] add per-phase timing metrics for sync processing (#6533)
* Adds metrics sync phases time split to monitor costs

* Address review fixes

* Increment README.md with description on usage with debug bundles
2026-06-29 11:02:02 +02:00
Maycon Santos
1409a1325a [misc] Update careers page link (#6538) 2026-06-29 09:19:01 +02:00
Viktor Liu
4400372f37 [client] Forward non-address DNS record types through route forwarders (#6455) 2026-06-28 18:50:17 +02:00
64 changed files with 2176 additions and 331 deletions

View File

@@ -33,7 +33,7 @@
<br/>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
</strong>
</p>

View File

@@ -10,7 +10,7 @@ var (
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
EnvKeyNBLazyConn = lazyconn.EnvLazyConn
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold

View File

@@ -71,12 +71,14 @@ var (
extraIFaceBlackList []string
anonymizeFlag bool
dnsRouteInterval time.Duration
lazyConnEnabled bool
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
// lazyConnEnabled is the parse target for the deprecated --enable-lazy-connection
// flag. The flag is inert; the value is no longer read (use NB_LAZY_CONN instead).
lazyConnEnabled bool
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",
@@ -210,7 +212,8 @@ func init() {
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "Deprecated: no longer used. Lazy connections are controlled by the server and the NB_LAZY_CONN environment variable.")
_ = upCmd.PersistentFlags().MarkDeprecated(enableLazyConnectionFlag, "no longer used; lazy connections are controlled by the server and the NB_LAZY_CONN environment variable")
}

View File

@@ -479,10 +479,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
req.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
req.LazyConnectionEnabled = &lazyConnEnabled
}
return &req
}
@@ -600,9 +596,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.DisableIPv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
ic.LazyConnectionEnabled = &lazyConnEnabled
}
return &ic, nil
}
@@ -718,9 +711,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.DisableIpv6 = &disableIPv6
}
if cmd.Flag(enableLazyConnectionFlag).Changed {
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
}
return &loginRequest, nil
}

View File

@@ -136,6 +136,11 @@ func (p *ProxyBind) CloseConn() error {
return p.close()
}
// InjectPacket is a no-op for the userspace proxy: first-packet reinjection is kernel-only.
func (p *ProxyBind) InjectPacket(_ []byte) error {
return nil
}
func (p *ProxyBind) close() error {
if p.remoteConn == nil {
return nil

View File

@@ -219,6 +219,17 @@ func (p *ProxyWrapper) RedirectAs(endpoint *net.UDPAddr) {
p.pausedCond.L.Unlock()
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *ProxyWrapper) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return errors.New("proxy not started")
}
if _, err := p.remoteConn.Write(b); err != nil {
return err
}
return nil
}
// CloseConn close the remoteConn and automatically remove the conn instance from the map
func (p *ProxyWrapper) CloseConn() error {
if p.cancel == nil {

View File

@@ -18,4 +18,9 @@ type Proxy interface {
RedirectAs(endpoint *net.UDPAddr)
CloseConn() error
SetDisconnectListener(disconnected func())
// InjectPacket writes a raw packet directly to the remote peer over the underlying transport,
// bypassing WireGuard. Used to replay the captured lazyconn handshake initiation. Only the
// kernel-mode proxies act on it; the userspace proxy is a no-op since reinjection is kernel-only.
InjectPacket(b []byte) error
}

View File

@@ -147,6 +147,17 @@ func (p *WGUDPProxy) RedirectAs(endpoint *net.UDPAddr) {
p.sendPkg = p.srcFakerConn.SendPkg
}
// InjectPacket writes b to the remote peer over the underlying transport.
func (p *WGUDPProxy) InjectPacket(b []byte) error {
if p.remoteConn == nil {
return errors.New("proxy not started")
}
if _, err := p.remoteConn.Write(b); err != nil {
return err
}
return nil
}
// CloseConn close the localConn
func (p *WGUDPProxy) CloseConn() error {
if p.cancel == nil {

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -30,11 +31,13 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
previousConfigHash uint64
hasAppliedConfig bool
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
@@ -57,6 +60,23 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
return
}
// Skip the full rebuild + flush when the inputs that drive the firewall
// state are byte-for-byte identical to the last successfully applied
// update. Management re-sends the same network map far more often than it
// actually changes (account-wide updates, peer meta churn), and rebuilding
// every peer/route ACL and flushing the firewall on every such sync is the
// dominant client-side cost when nothing changed. Mirrors the same guard the
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
// consumes participate in the hash, so an unrelated map change cannot mask a
// real ACL change.
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
if err != nil {
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
return
}
start := time.Now()
defer func() {
total := 0
@@ -70,13 +90,49 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
if routeErr != nil {
log.Errorf("Failed to apply route ACLs: %v", routeErr)
}
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
flushErr := d.firewall.Flush()
if flushErr != nil {
log.Error("failed to flush firewall rules: ", flushErr)
}
// Only remember the hash once the firewall actually reflects this config.
// If applying or flushing failed, leave the previous hash untouched so the
// next (possibly identical) update is not skipped and gets a chance to
// reconcile the firewall state.
if err == nil && routeErr == nil && flushErr == nil {
d.previousConfigHash = hash
d.hasAppliedConfig = true
} else {
d.hasAppliedConfig = false
}
}
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
// firewall state, so an identical hash means an identical resulting ruleset.
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
return hashstructure.Hash(struct {
PeerRules []*mgmProto.FirewallRule
PeerRulesIsEmpty bool
RouteRules []*mgmProto.RouteFirewallRule
RouteRulesIsEmpty bool
DNSRouteFeatureFlag bool
}{
PeerRules: networkMap.GetFirewallRules(),
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
RouteRules: networkMap.GetRoutesFirewallRules(),
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
}, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {

View File

@@ -1,6 +1,7 @@
package acl
import (
"fmt"
"net/netip"
"testing"
@@ -485,3 +486,149 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
// re-applied is recognized as a no-op (hash unchanged), while a real change to
// any firewall-relevant input forces a re-apply (hash changes). This is the
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
firstHash := acl.previousConfigHash
require.NotZero(t, firstHash)
// Re-applying the identical map must not change the recorded hash: the
// expensive rebuild path was skipped.
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, firstHash, acl.previousConfigHash,
"identical re-apply must be a no-op (hash unchanged)")
// A real change must produce a different hash and re-apply.
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
acl.ApplyFiltering(networkMap, false)
assert.NotEqual(t, firstHash, acl.previousConfigHash,
"changing a rule's action must force a re-apply (hash changed)")
// The dnsRouteFeatureFlag also participates in the hash.
changedHash := acl.previousConfigHash
acl.ApplyFiltering(networkMap, true)
assert.NotEqual(t, changedHash, acl.previousConfigHash,
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
}
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
nm := &mgmProto.NetworkMap{
FirewallRulesIsEmpty: peerRules == 0,
RoutesFirewallRulesIsEmpty: routeRules == 0,
}
for i := range peerRules {
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: fmt.Sprintf("%d", 1024+i%64511),
})
}
for i := range routeRules {
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
})
}
return nm
}
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(10, 5)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(100, 50)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(1000, 200)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
// inputs and order-independent for the rule slices (management does not
// guarantee rule order).
func TestFirewallConfigHashDeterministic(t *testing.T) {
d := &DefaultManager{}
nm1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
},
}
// Same rules, reversed order.
nm2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
nm1.FirewallRules[1],
nm1.FirewallRules[0],
},
}
h1, err := d.firewallConfigHash(nm1, false)
require.NoError(t, err)
h2, err := d.firewallConfigHash(nm2, false)
require.NoError(t, err)
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
}

View File

@@ -322,7 +322,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
a.config.BlockLANAccess,
a.config.BlockInbound,
a.config.DisableIPv6,
a.config.LazyConnectionEnabled,
a.config.EnableSSHRoot,
a.config.EnableSSHSFTP,
a.config.EnableSSHLocalPortForwarding,

View File

@@ -16,6 +16,16 @@ import (
"github.com/netbirdio/netbird/route"
)
// lazyForce is the resolved local decision for lazy connections, layered above the
// management feature flag. lazyForceNone defers to management.
type lazyForce int
const (
lazyForceNone lazyForce = iota
lazyForceOn
lazyForceOff
)
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
//
// The connection manager is responsible for:
@@ -28,7 +38,7 @@ type ConnMgr struct {
peerStore *peerstore.Store
statusRecorder *peer.Status
iface lazyconn.WGIface
enabledLocally bool
force lazyForce
rosenpassEnabled bool
lazyConnMgr *manager.Manager
@@ -43,28 +53,34 @@ func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerSto
peerStore: peerStore,
statusRecorder: statusRecorder,
iface: iface,
force: resolveLazyForce(engineConfig.LazyConnection),
rosenpassEnabled: engineConfig.RosenpassEnabled,
}
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
e.enabledLocally = true
}
return e
}
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
// Start initializes the connection manager. It starts the lazy connection manager when a
// local override forces it on; with no local override it waits for the management feature flag.
func (e *ConnMgr) Start(ctx context.Context) {
if e.lazyConnMgr != nil {
log.Errorf("lazy connection manager is already started")
return
}
if !e.enabledLocally {
log.Infof("lazy connection manager is disabled")
switch e.force {
case lazyForceOff:
log.Infof("lazy connection manager is disabled by local override (%s or MDM policy)", lazyconn.EnvLazyConn)
e.statusRecorder.UpdateLazyConnection(false)
return
case lazyForceNone:
log.Infof("lazy connection manager is managed by the management feature flag")
e.statusRecorder.UpdateLazyConnection(false)
return
}
if e.rosenpassEnabled {
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
e.statusRecorder.UpdateLazyConnection(false)
return
}
@@ -76,8 +92,8 @@ func (e *ConnMgr) Start(ctx context.Context) {
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
// do not disable lazy connection manager if it was enabled by env var
if e.enabledLocally {
// a local override (NB_LAZY_CONN or local config) takes precedence over management
if e.force != lazyForceNone {
return nil
}
@@ -89,6 +105,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
if e.rosenpassEnabled {
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
@@ -98,6 +115,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
return e.addPeersToLazyConnManager()
} else {
if e.lazyConnMgr == nil {
e.statusRecorder.UpdateLazyConnection(false)
return nil
}
log.Infof("lazy connection manager is disabled by management feature flag")
@@ -309,6 +327,25 @@ func (e *ConnMgr) isStartedWithLazyMgr() bool {
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
}
// resolveLazyForce determines the local override. NB_LAZY_CONN takes precedence; when it
// is unset the MDM policy override (mdmState) applies. Either wins in both directions over
// the management feature flag; StateUnset for both defers to management.
func resolveLazyForce(mdmState lazyconn.State) lazyForce {
state := lazyconn.EnvState()
if state == lazyconn.StateUnset {
state = mdmState
}
switch state {
case lazyconn.StateOn:
return lazyForceOn
case lazyconn.StateOff:
return lazyForceOff
default:
return lazyForceNone
}
}
func inactivityThresholdEnv() *time.Duration {
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
if envValue == "" {

View File

@@ -0,0 +1,40 @@
package internal
import (
"os"
"testing"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
func TestResolveLazyForce(t *testing.T) {
tests := []struct {
name string
env string
envSet bool
mdm lazyconn.State
want lazyForce
}{
{name: "env unset, mdm unset -> defer to management", mdm: lazyconn.StateUnset, want: lazyForceNone},
{name: "env on -> force on", env: "on", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOn},
{name: "env off -> force off", env: "off", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOff},
{name: "env unset, mdm on -> force on", mdm: lazyconn.StateOn, want: lazyForceOn},
{name: "env unset, mdm off -> force off", mdm: lazyconn.StateOff, want: lazyForceOff},
{name: "env on beats mdm off", env: "on", envSet: true, mdm: lazyconn.StateOff, want: lazyForceOn},
{name: "env off beats mdm on", env: "off", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOff},
{name: "unrecognized env, mdm on -> mdm wins", env: "auto", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOn},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv(lazyconn.EnvLazyConn, tt.env)
if !tt.envSet {
os.Unsetenv(lazyconn.EnvLazyConn)
}
if got := resolveLazyForce(tt.mdm); got != tt.want {
t.Fatalf("resolveLazyForce(%v) = %v, want %v", tt.mdm, got, tt.want)
}
})
}
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -596,7 +597,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
BlockInbound: config.BlockInbound,
DisableIPv6: config.DisableIPv6,
LazyConnectionEnabled: config.LazyConnectionEnabled,
LazyConnection: lazyconn.ParseState(config.LazyConnection),
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
@@ -670,7 +671,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.BlockLANAccess,
config.BlockInbound,
config.DisableIPv6,
config.LazyConnectionEnabled,
config.EnableSSHRoot,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,

View File

@@ -681,7 +681,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("LazyConnection: %q\n", g.internalConfig.LazyConnection))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}

View File

@@ -885,7 +885,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
LazyConnection: "on",
MTU: 1280,
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"net"
"net/netip"
"slices"
"strings"
"github.com/miekg/dns"
@@ -167,7 +168,10 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
// Non-address types reach LookupIP only unexpectedly; without an
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -184,6 +188,230 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess
}
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net"
"net/netip"
"strings"
"testing"
"github.com/miekg/dns"
@@ -121,6 +122,164 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{

View File

@@ -37,6 +37,12 @@ const (
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
type firewaller interface {
@@ -210,12 +216,6 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
@@ -227,9 +227,46 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
reqHasEdns := query.IsEdns0() != nil
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if reqHasEdns {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
reqHasEdns bool,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
return
}
@@ -240,6 +277,25 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
f.writeResponse(logger, w, resp, qname, startTime)
}
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)

View File

@@ -133,6 +133,41 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1)
}
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
@@ -545,12 +580,15 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
expectEDE bool
description string
}{
{
@@ -562,28 +600,13 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
name: "unsupported query type returns NODATA",
queryType: dns.TypeCAA,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
expectedCode: dns.RcodeSuccess,
expectEDE: true,
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
},
}
@@ -599,6 +622,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response
var writtenResp *dns.Msg
@@ -614,10 +638,213 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string

View File

@@ -40,6 +40,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/expose"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/metrics"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -82,6 +83,12 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
disableAutoUpdate = "disabled"
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
// check gathering. The gathering runs uncancellable system calls (process scan,
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
systemInfoTimeout = 15 * time.Second
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -141,7 +148,9 @@ type EngineConfig struct {
BlockInbound bool
DisableIPv6 bool
LazyConnectionEnabled bool
// LazyConnection is the MDM-sourced lazy-connection override; StateUnset defers to
// the env var and management feature flag.
LazyConnection lazyconn.State
MTU uint16
@@ -895,6 +904,16 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
// phase times a sync sub-phase: it returns a function that records the elapsed
// duration when called. Starting the timer at the call site keeps inter-phase
// glue code out of the measurement.
func (e *Engine) phase(name string) func() {
start := time.Now()
return func() {
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
}
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
@@ -914,7 +933,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
done := e.phase("netbird_config")
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
done()
if err != nil {
return err
}
@@ -928,11 +950,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
done = e.phase("checks")
err = e.updateChecksIfNew(update.Checks)
done()
if err != nil {
return err
}
done = e.phase("persist")
e.persistSyncResponse(update)
done()
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -1066,11 +1093,22 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks, e.overlayAddresses()...)
if !ok {
// Gathering timed out; skip the meta sync this cycle rather than blocking the
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
return nil
}
e.applyInfoFlags(info)
if err := e.mgmClient.SyncMeta(info); err != nil {
return fmt.Errorf("could not sync meta: error %s", err)
}
return nil
}
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
func (e *Engine) applyInfoFlags(info *system.Info) {
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
@@ -1082,19 +1120,12 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
log.Errorf("could not sync meta: error %s", err)
return err
}
return nil
}
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
@@ -1254,31 +1285,15 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks, e.overlayAddresses()...)
if !ok {
// Gathering timed out; connect the stream with base info so management
// connectivity still comes up rather than blocking here.
info = system.GetInfo(e.ctx)
}
info.SetFlags(
e.config.RosenpassEnabled,
e.config.RosenpassPermissive,
&e.config.ServerSSHAllowed,
e.config.DisableClientRoutes,
e.config.DisableServerRoutes,
e.config.DisableDNS,
e.config.DisableFirewall,
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
e.applyInfoFlags(info)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1371,13 +1386,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
done := e.phase("dns_server")
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
done()
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled
done = e.phase("routes_classify")
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1386,29 +1404,60 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
done()
done = e.phase("routes_apply")
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
done()
done = e.phase("filtering")
if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
done()
done = e.phase("dns_forwarder")
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
done()
// Ingress forward rules
done = e.phase("forward_rules")
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
done()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
done = e.phase("offline_peers")
e.updateOfflinePeers(networkMap.GetOfflinePeers())
done()
remotePeers, err := e.reconcilePeers(networkMap)
if err != nil {
return err
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
done = e.phase("lazy_exclude")
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done()
e.networkSerial = serial
return nil
}
// reconcilePeers applies the remote peer list from the network map (removing,
// modifying and adding peers, then updating SSH config) and returns the remote
// peers with our own peer filtered out, for use by later sync steps.
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
@@ -1423,42 +1472,43 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return err
return nil, err
}
} else {
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done := e.phase("removed_peers")
err := e.removePeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.networkSerial = serial
done = e.phase("modified_peers")
err = e.modifyPeers(remotePeers)
done()
if err != nil {
return nil, err
}
return nil
done = e.phase("added_peers")
err = e.addNewPeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
@@ -1938,7 +1988,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.BlockLANAccess,
e.config.BlockInbound,
e.config.DisableIPv6,
e.config.LazyConnectionEnabled,
e.config.EnableSSHRoot,
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,

View File

@@ -178,6 +178,10 @@ func (m *MockWGIface) LastActivities() map[string]monotime.Time {
return nil
}
func (m *MockWGIface) MTU() uint16 {
return 1280
}
func (m *MockWGIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
return nil
}

View File

@@ -44,4 +44,5 @@ type wgIfaceBase interface {
FullStats() (*configurer.Stats, error)
LastActivities() map[string]monotime.Time
SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error
MTU() uint16
}

View File

@@ -124,6 +124,11 @@ func (d *BindListener) ReadPackets() {
d.done.Done()
}
// CapturedPacket is unused in userspace bind mode: first-packet reinjection is kernel-only.
func (d *BindListener) CapturedPacket() []byte {
return nil
}
// Close stops the listener and cleans up resources.
func (d *BindListener) Close() {
d.peerCfg.Log.Infof("closing activity listener (LazyConn)")

View File

@@ -45,10 +45,6 @@ type MockWGIfaceBind struct {
endpointMgr *mockEndpointManager
}
func (m *MockWGIfaceBind) RemovePeer(string) error {
return nil
}
func (m *MockWGIfaceBind) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
@@ -68,6 +64,10 @@ func (m *MockWGIfaceBind) GetBind() device.EndpointManager {
return m.endpointMgr
}
func (m *MockWGIfaceBind) MTU() uint16 {
return 1280
}
func TestBindListener_Creation(t *testing.T) {
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -207,8 +207,9 @@ func TestManager_BindMode(t *testing.T) {
require.NoError(t, err)
select {
case peerConnID := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, peerConnID, "Received peer connection ID should match")
case ev := <-mgr.OnActivityChan:
assert.Equal(t, cfg.PeerConnID, ev.PeerConnID, "Received peer connection ID should match")
assert.Nil(t, ev.FirstPacket, "Bind mode does not capture packets: reinjection is kernel-only")
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notification")
}
@@ -266,8 +267,8 @@ func TestManager_BindMode_MultiplePeers(t *testing.T) {
receivedPeers := make(map[peerid.ConnID]bool)
for i := 0; i < 2; i++ {
select {
case peerConnID := <-mgr.OnActivityChan:
receivedPeers[peerConnID] = true
case ev := <-mgr.OnActivityChan:
receivedPeers[ev.PeerConnID] = true
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for activity notifications")
}

View File

@@ -3,11 +3,13 @@ package activity
import (
"fmt"
"net"
"slices"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/bufsize"
"github.com/netbirdio/netbird/client/internal/lazyconn"
)
@@ -20,6 +22,8 @@ type UDPListener struct {
done sync.Mutex
isClosed atomic.Bool
capturedPacket []byte
}
// NewUDPListener creates a listener that detects activity via UDP socket reads.
@@ -46,9 +50,13 @@ func NewUDPListener(wgIface WgInterface, cfg lazyconn.PeerConfig) (*UDPListener,
}
// ReadPackets blocks reading from the UDP socket until activity is detected or the listener is closed.
// The first packet that triggers activity is captured so it can be reinjected through the real
// transport once it is established. Without this, kernel WireGuard's handshake initiation would be
// dropped and WG would only retry after REKEY_TIMEOUT.
func (d *UDPListener) ReadPackets() {
for {
n, remoteAddr, err := d.conn.ReadFromUDP(make([]byte, 1))
buf := make([]byte, int(d.wgIface.MTU())+bufsize.WGBufferOverhead)
n, remoteAddr, err := d.conn.ReadFromUDP(buf)
if err != nil {
if d.isClosed.Load() {
d.peerCfg.Log.Infof("exit from activity listener")
@@ -62,20 +70,24 @@ func (d *UDPListener) ReadPackets() {
d.peerCfg.Log.Warnf("received %d bytes from %s, too short", n, remoteAddr)
continue
}
d.peerCfg.Log.Infof("activity detected")
d.capturedPacket = slices.Clone(buf[:n])
d.peerCfg.Log.Infof("activity detected, captured %d bytes for reinjection", n)
break
}
d.peerCfg.Log.Debugf("removing lazy endpoint: %s", d.endpoint.String())
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
// Ignore close error as it may return "use of closed network connection" if already closed.
// Leave the peer in place. ConfigureWGEndpoint will UpdatePeer with the real endpoint;
// removing the peer here wipes kernel WG's staged queue and drops the user packet that
// triggered activation.
_ = d.conn.Close()
d.done.Unlock()
}
// CapturedPacket returns the first packet that triggered activity, or nil if none was captured.
// Safe to call after ReadPackets returns.
func (d *UDPListener) CapturedPacket() []byte {
return d.capturedPacket
}
// Close stops the listener and cleans up resources.
func (d *UDPListener) Close() {
d.peerCfg.Log.Infof("closing activity listener: %s", d.conn.LocalAddr().String())

View File

@@ -19,17 +19,25 @@ import (
type listener interface {
ReadPackets()
Close()
CapturedPacket() []byte
}
// Event reports activity on a managed peer. FirstPacket is the bytes that triggered activation,
// captured for reinjection through the real transport.
type Event struct {
PeerConnID peerid.ConnID
FirstPacket []byte
}
type WgInterface interface {
RemovePeer(peerKey string) error
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
IsUserspaceBind() bool
Address() wgaddr.Address
MTU() uint16
}
type Manager struct {
OnActivityChan chan peerid.ConnID
OnActivityChan chan Event
wgIface WgInterface
@@ -41,7 +49,7 @@ type Manager struct {
func NewManager(wgIface WgInterface) *Manager {
m := &Manager{
OnActivityChan: make(chan peerid.ConnID, 1),
OnActivityChan: make(chan Event, 1),
wgIface: wgIface,
peers: make(map[peerid.ConnID]listener),
done: make(chan struct{}),
@@ -116,12 +124,12 @@ func (m *Manager) waitForTraffic(l listener, peerConnID peerid.ConnID) {
delete(m.peers, peerConnID)
m.mu.Unlock()
m.notify(peerConnID)
m.notify(Event{PeerConnID: peerConnID, FirstPacket: l.CapturedPacket()})
}
func (m *Manager) notify(peerConnID peerid.ConnID) {
func (m *Manager) notify(ev Event) {
select {
case <-m.done:
case m.OnActivityChan <- peerConnID:
case m.OnActivityChan <- ev:
}
}

View File

@@ -1,6 +1,7 @@
package activity
import (
"bytes"
"net"
"net/netip"
"testing"
@@ -25,10 +26,6 @@ func (m *MocPeer) ConnID() peerid.ConnID {
type MocWGIface struct {
}
func (m MocWGIface) RemovePeer(string) error {
return nil
}
func (m MocWGIface) UpdatePeer(string, []netip.Prefix, time.Duration, *net.UDPAddr, *wgtypes.Key) error {
return nil
}
@@ -44,6 +41,10 @@ func (m MocWGIface) Address() wgaddr.Address {
}
}
func (m MocWGIface) MTU() uint16 {
return 1280
}
// GetPeerListener is a test helper to access listeners
func (m *Manager) GetPeerListener(peerConnID peerid.ConnID) (listener, bool) {
m.mu.Lock()
@@ -86,11 +87,15 @@ func TestManager_MonitorPeerActivity(t *testing.T) {
}
select {
case peerConnID := <-mgr.OnActivityChan:
if peerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", peerConnID)
case ev := <-mgr.OnActivityChan:
if ev.PeerConnID != peerCfg1.PeerConnID {
t.Fatalf("unexpected peerConnID: %v", ev.PeerConnID)
}
if !bytes.Equal(ev.FirstPacket, []byte{0x01, 0x02, 0x03, 0x04, 0x05}) {
t.Fatalf("unexpected first packet: %v", ev.FirstPacket)
}
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for activity")
}
}

View File

@@ -3,24 +3,57 @@ package lazyconn
import (
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
EnvLazyConn = "NB_LAZY_CONN"
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
)
func IsLazyConnEnabledByEnv() bool {
val := os.Getenv(EnvEnableLazyConn)
if val == "" {
return false
}
enabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
return false
}
return enabled
// State is the tri-state local override for lazy connections read from the environment.
type State int
const (
// StateUnset means no local override; defer to the management feature flag.
StateUnset State = iota
// StateOn forces lazy connections on, overriding management.
StateOn
// StateOff forces lazy connections off, overriding management.
StateOff
)
// EnvState reads NB_LAZY_CONN and returns the local override state.
func EnvState() State {
return ParseState(os.Getenv(EnvLazyConn))
}
// ParseState interprets a lazy-connection override value (from the environment or an MDM
// policy). It accepts the on/off aliases plus any value strconv.ParseBool understands
// (true/false/1/0). An empty or unrecognized value returns StateUnset so that the
// management feature flag remains in control.
func ParseState(raw string) State {
if raw == "" {
return StateUnset
}
normalized := strings.ToLower(strings.TrimSpace(raw))
switch normalized {
case "on":
return StateOn
case "off":
return StateOff
}
enabled, err := strconv.ParseBool(normalized)
if err != nil {
log.Warnf("failed to parse %s value %q: %v", EnvLazyConn, raw, err)
return StateUnset
}
if enabled {
return StateOn
}
return StateOff
}

View File

@@ -0,0 +1,45 @@
package lazyconn
import (
"os"
"testing"
)
func TestEnvState(t *testing.T) {
tests := []struct {
value string
set bool
want State
}{
{set: false, want: StateUnset},
{value: "", set: true, want: StateUnset},
{value: "on", set: true, want: StateOn},
{value: "ON", set: true, want: StateOn},
{value: "true", set: true, want: StateOn},
{value: "1", set: true, want: StateOn},
{value: " on ", set: true, want: StateOn},
{value: "off", set: true, want: StateOff},
{value: "OFF", set: true, want: StateOff},
{value: "false", set: true, want: StateOff},
{value: "0", set: true, want: StateOff},
{value: "auto", set: true, want: StateUnset},
{value: "garbage", set: true, want: StateUnset},
}
for _, tt := range tests {
name := tt.value
if !tt.set {
name = "unset"
}
t.Run(name, func(t *testing.T) {
t.Setenv(EnvLazyConn, tt.value)
if !tt.set {
os.Unsetenv(EnvLazyConn)
}
if got := EnvState(); got != tt.want {
t.Fatalf("EnvState() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -130,8 +130,8 @@ func (m *Manager) Start(ctx context.Context) {
select {
case <-ctx.Done():
return
case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(peerConnID)
case ev := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ev)
case peerIDs := <-m.inactivityManager.InactivePeersChan():
m.onPeerInactivityTimedOut(peerIDs)
}
@@ -513,13 +513,13 @@ func (m *Manager) checkHaGroupActivity(haGroup route.HAUniqueID, peerID string,
return false
}
func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
func (m *Manager) onPeerActivity(ev activity.Event) {
m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock()
mp, ok := m.managedPeersByConnID[peerConnID]
mp, ok := m.managedPeersByConnID[ev.PeerConnID]
if !ok {
log.Errorf("peer not found by conn id: %v", peerConnID)
log.Errorf("peer not found by conn id: %v", ev.PeerConnID)
return
}
@@ -536,7 +536,7 @@ func (m *Manager) onPeerActivity(peerConnID peerid.ConnID) {
m.activateHAGroupPeers(mp.peerCfg)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
m.peerStore.PeerConnOpenWithFirstPacket(m.engineCtx, mp.peerCfg.PublicKey, ev.FirstPacket)
}
func (m *Manager) onPeerInactivityTimedOut(peerIDs map[string]struct{}) {

View File

@@ -17,4 +17,5 @@ type WGIface interface {
IsUserspaceBind() bool
Address() wgaddr.Address
LastActivities() map[string]monotime.Time
MTU() uint16
}

View File

@@ -120,6 +120,30 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
phase,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync_phase",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {

View File

@@ -78,6 +78,25 @@ Tags:
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Sync Phase Timing
Measurement: `netbird_sync_phase`
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
Tags:
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
### Login Duration
Measurement: `netbird_login`
@@ -191,4 +210,52 @@ docker compose exec influxdb influx query \
# Check ingest server health
curl http://localhost:8087/health
```
```
## Analyzing a Debug Bundle
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
### 1. Start the stack
```bash
# From this directory (client/internal/metrics/infra)
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
docker compose up -d
```
(`admin123` are throwaway local credentials — fine for offline analysis.)
### 2. Clear any previous data
So you only see this bundle:
```bash
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
```
### 3. Import the bundle's metrics.txt
InfluxDB is not exposed on the host, so import inside the container:
```bash
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
--token admin123 --file /tmp/m.txt
```
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
### 4. View the dashboards
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.

View File

@@ -0,0 +1,259 @@
{
"annotations": {
"list": []
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"refresh": "",
"schemaVersion": 39,
"tags": [
"netbird",
"sync"
],
"templating": {
"list": [
{
"current": {
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"includeAll": true,
"label": "version",
"multi": true,
"name": "version",
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"refresh": 2,
"type": "query",
"allValue": ".*"
}
]
},
"time": {
"from": "now-2d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "NetBird Sync Phases (where time goes)",
"uid": "netbird-sync-phases",
"version": 1,
"panels": [
{
"id": 1,
"title": "Time per phase over time (stacked, ms)",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "bars",
"stacking": {
"mode": "normal",
"group": "A"
},
"fillOpacity": 80,
"lineWidth": 0
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
}
]
},
{
"id": 2,
"title": "p95 per phase (ms)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"color": {
"mode": "continuous-GrYlRd"
}
},
"overrides": []
},
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
}
]
},
{
"id": 3,
"title": "Per-phase stats (ms): mean / p95 / max",
"type": "table",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms"
},
"overrides": []
},
"options": {
"showHeader": true,
"sortBy": [
{
"displayName": "max",
"desc": true
}
]
},
"transformations": [
{
"id": "merge",
"options": {}
}
],
"targets": [
{
"refId": "mean",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
},
{
"refId": "p95",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
},
{
"refId": "max",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
}
]
},
{
"id": 4,
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 24,
"x": 0,
"y": 21
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "points",
"pointSize": 5
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "single"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
}
]
}
]
}

View File

@@ -19,7 +19,7 @@ const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
@@ -59,6 +59,19 @@ var allowedMeasurements = map[string]measurementSpec{
"peer_id": true,
},
},
"netbird_sync_phase": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"phase": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,

View File

@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")

View File

@@ -56,6 +56,9 @@ type metricsImplementation interface {
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordSyncPhase records how long a single sub-phase of sync processing took
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
@@ -127,6 +130,18 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordSyncPhase records the duration of a single sub-phase of sync processing
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {

View File

@@ -70,6 +70,9 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}

View File

@@ -6,6 +6,7 @@ import (
"net"
"net/netip"
"runtime"
"slices"
"sync"
"time"
@@ -136,6 +137,39 @@ type Conn struct {
// Connection stage timestamps for metrics
metricsRecorder MetricsRecorder
metricsStages *MetricsStages
// pendingFirstPacket is the lazyconn-captured handshake init, replayed once the real
// transport is up.
pendingFirstPacket []byte
}
// injectPendingFirstPacket replays the captured handshake through the proxy if present, else
// directly through the ICE conn. The packet is cleared only after a successful write, so a failed
// or transport-less attempt leaves it available for a later reinjection. Caller must hold conn.mu.
func (conn *Conn) injectPendingFirstPacket(proxy wgproxy.Proxy, directConn net.Conn) {
pkt := conn.pendingFirstPacket
if len(pkt) == 0 {
return
}
switch {
case proxy != nil:
if err := proxy.InjectPacket(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via proxy: %v", err)
return
}
case directConn != nil:
if _, err := directConn.Write(pkt); err != nil {
conn.Log.Debugf("failed to reinject captured first packet via direct conn: %v", err)
return
}
default:
conn.Log.Debugf("no transport available to reinject captured first packet")
return
}
conn.pendingFirstPacket = nil
conn.Log.Debugf("reinjected captured first packet (%d bytes)", len(pkt))
}
// NewConn creates a new not opened Conn to the remote peer.
@@ -172,6 +206,16 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
// be used.
func (conn *Conn) Open(engineCtx context.Context) error {
return conn.open(engineCtx, nil)
}
// OpenWithFirstPacket opens the connection like Open and stashes firstPacket to be replayed once
// the real transport is established. The packet is retained only on a successful open.
func (conn *Conn) OpenWithFirstPacket(engineCtx context.Context, firstPacket []byte) error {
return conn.open(engineCtx, firstPacket)
}
func (conn *Conn) open(engineCtx context.Context, firstPacket []byte) error {
conn.mu.Lock()
defer conn.mu.Unlock()
@@ -227,6 +271,9 @@ func (conn *Conn) Open(engineCtx context.Context) error {
defer conn.wg.Done()
conn.guard.Start(conn.ctx, conn.onGuardEvent)
}()
if len(firstPacket) > 0 {
conn.pendingFirstPacket = slices.Clone(firstPacket)
}
conn.opened = true
return nil
}
@@ -423,6 +470,8 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
conn.wgProxyRelay.RedirectAs(ep)
}
conn.injectPendingFirstPacket(wgProxy, iceConnInfo.RemoteConn)
conn.currentConnPriority = priority
conn.statusICE.SetConnected()
conn.updateIceState(iceConnInfo, updateTime)
@@ -546,6 +595,8 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
wgConfigWorkaround()
conn.injectPendingFirstPacket(wgProxy, nil)
conn.rosenpassRemoteKey = rci.rosenpassPubKey
conn.currentConnPriority = conntype.Relay
conn.statusRelay.SetConnected()

View File

@@ -88,11 +88,24 @@ func (s *Store) PeerConnOpen(ctx context.Context, pubKey string) {
if !ok {
return
}
// this can be blocked because of the connect open limiter semaphore
if err := p.Open(ctx); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
// PeerConnOpenWithFirstPacket opens the peer connection and stashes a first packet to be
// reinjected once the real transport is established.
func (s *Store) PeerConnOpenWithFirstPacket(ctx context.Context, pubKey string, firstPacket []byte) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return
}
if err := p.OpenWithFirstPacket(ctx, firstPacket); err != nil {
p.Log.Errorf("failed to open peer connection: %v", err)
}
}
func (s *Store) PeerConnIdle(pubKey string) {

View File

@@ -101,8 +101,6 @@ type ConfigInput struct {
DNSLabels domain.List
LazyConnectionEnabled *bool
MTU *uint16
}
@@ -180,7 +178,9 @@ type Config struct {
ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
// LazyConnection is the MDM-managed lazy-connection override ("on"/"off"/"").
// Runtime-only: re-derived from MDM policy on each load, never persisted.
LazyConnection string `json:"-"`
MTU uint16
@@ -632,12 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
updated = true
}
if input.MTU != nil && *input.MTU != config.MTU {
log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU)
config.MTU = *input.MTU
@@ -728,6 +722,11 @@ func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
}
}
if v, ok := policy.GetString(mdm.KeyLazyConnection); ok {
config.LazyConnection = v
logApplied(mdm.KeyLazyConnection, v)
}
}
// parseURL parses and validates the URL for the named service. The URL

View File

@@ -226,12 +226,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
// All query types for an intercepted domain are forwarded to the peer's
// DNS forwarder, which owns the name. Falling through to the system
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
// for, poisoning the whole name (including the A/AAAA records the route
// does serve). The forwarder answers NODATA for types it cannot resolve.
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
@@ -293,19 +292,6 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {

View File

@@ -38,7 +38,7 @@ func GetEnvKeyNBForceRelay() string {
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
func GetEnvKeyNBLazyConn() string {
return lazyconn.EnvEnableLazyConn
return lazyconn.EnvLazyConn
}
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client

View File

@@ -41,6 +41,10 @@ const (
// construction — only one mode can be set at a time.
KeySplitTunnelMode = "splitTunnelMode"
KeySplitTunnelApps = "splitTunnelApps"
// KeyLazyConnection forces the lazy-connection feature on or off, overriding
// the management feature flag. Values: "on" / "off" (absent = defer to management).
KeyLazyConnection = "lazyConnection"
)
// Split-tunnel mode literals (KeySplitTunnelMode values).
@@ -67,7 +71,6 @@ var boolStringLiterals = map[string]bool{
"no": false,
}
// Policy holds MDM-managed settings read from the platform source. A nil or
// empty Policy means no enforcement is active.
type Policy struct {

View File

@@ -152,7 +152,6 @@ func (s *Server) restartEngineForMDMLocked() error {
s.config = config
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
ctx, cancel := context.WithCancel(s.rootCtx)
s.actCancel = cancel
@@ -305,7 +304,6 @@ func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
msg.DisableFirewall != nil ||
msg.BlockLanAccess != nil ||
msg.DisableNotifications != nil ||
msg.LazyConnectionEnabled != nil ||
msg.BlockInbound != nil ||
msg.DisableIpv6 != nil ||
msg.EnableSSHRoot != nil ||
@@ -348,7 +346,6 @@ func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
msg.BlockLanAccess != nil ||
msg.DisableNotifications != nil ||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
msg.LazyConnectionEnabled != nil ||
msg.BlockInbound != nil
}

View File

@@ -214,7 +214,6 @@ func (s *Server) Start() error {
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
if s.sessionWatcher == nil {
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
@@ -463,7 +462,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
config.DisableFirewall = msg.DisableFirewall
config.BlockLANAccess = msg.BlockLanAccess
config.DisableNotifications = msg.DisableNotifications
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
config.BlockInbound = msg.BlockInbound
config.DisableIPv6 = msg.DisableIpv6
config.EnableSSHRoot = msg.EnableSSHRoot
@@ -1647,7 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
RosenpassPermissive: cfg.RosenpassPermissive,
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
BlockInbound: cfg.BlockInbound,
DisableNotifications: disableNotifications,
NetworkMonitor: networkMonitor,

View File

@@ -69,43 +69,41 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
disableFirewall := true
blockLANAccess := true
disableNotifications := true
lazyConnectionEnabled := true
blockInbound := true
disableIPv6 := true
mtu := int64(1280)
sshJWTCacheTTL := int32(300)
req := &proto.SetConfigRequest{
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
ProfileName: profName,
Username: currUser.Username,
ManagementUrl: "https://new-api.netbird.io:443",
AdminURL: "https://new-admin.netbird.io",
RosenpassEnabled: &rosenpassEnabled,
RosenpassPermissive: &rosenpassPermissive,
ServerSSHAllowed: &serverSSHAllowed,
InterfaceName: &interfaceName,
WireguardPort: &wireguardPort,
OptionalPreSharedKey: &preSharedKey,
DisableAutoConnect: &disableAutoConnect,
NetworkMonitor: &networkMonitor,
DisableClientRoutes: &disableClientRoutes,
DisableServerRoutes: &disableServerRoutes,
DisableDns: &disableDNS,
DisableFirewall: &disableFirewall,
BlockLanAccess: &blockLANAccess,
DisableNotifications: &disableNotifications,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
DnsLabels: []string{"label1", "label2"},
CleanDNSLabels: false,
DnsRouteInterval: durationpb.New(2 * time.Minute),
Mtu: &mtu,
SshJWTCacheTTL: &sshJWTCacheTTL,
}
_, err = s.SetConfig(ctx, req)
@@ -140,7 +138,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
require.NotNil(t, cfg.DisableNotifications)
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
require.Equal(t, blockInbound, cfg.BlockInbound)
require.Equal(t, disableIPv6, cfg.DisableIPv6)
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
@@ -164,13 +161,14 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
t.Helper()
metadataFields := map[string]bool{
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
"state": true, // protobuf internal
"sizeCache": true, // protobuf internal
"unknownFields": true, // protobuf internal
"Username": true, // metadata
"ProfileName": true, // metadata
"CleanNATExternalIPs": true, // control flag for clearing
"CleanDNSLabels": true, // control flag for clearing
"LazyConnectionEnabled": true, // deprecated: proto field retained for compat, no longer applied
}
expectedFields := map[string]bool{
@@ -190,7 +188,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
"DisableFirewall": true,
"BlockLanAccess": true,
"DisableNotifications": true,
"LazyConnectionEnabled": true,
"BlockInbound": true,
"DisableIpv6": true,
"NatExternalIPs": true,
@@ -252,7 +249,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
"block-lan-access": "BlockLanAccess",
"block-inbound": "BlockInbound",
"disable-ipv6": "DisableIpv6",
"enable-lazy-connection": "LazyConnectionEnabled",
"external-ip-map": "NatExternalIPs",
"dns-resolver-address": "CustomDNSAddress",
"extra-iface-blacklist": "ExtraIFaceBlacklist",
@@ -269,7 +265,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
fieldsWithoutCLIFlags := map[string]bool{
"DisableNotifications": true, // Only settable via UI
"DisableNotifications": true, // Only settable via UI
"LazyConnectionEnabled": true, // deprecated: no longer settable (managed by server + NB_LAZY_CONN)
}
// Get all SetConfigRequest fields to verify our map is complete.

View File

@@ -2,9 +2,11 @@ package system
import (
"context"
"errors"
"net/netip"
"slices"
"strings"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/metadata"
@@ -72,8 +74,6 @@ type Info struct {
BlockInbound bool
DisableIPv6 bool
LazyConnectionEnabled bool
EnableSSHRoot bool
EnableSSHSFTP bool
EnableSSHLocalPortForwarding bool
@@ -85,7 +85,7 @@ func (i *Info) SetFlags(
rosenpassEnabled, rosenpassPermissive bool,
serverSSHAllowed *bool,
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6 bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
) {
@@ -103,8 +103,6 @@ func (i *Info) SetFlags(
i.BlockInbound = blockInbound
i.DisableIPv6 = disableIPv6
i.LazyConnectionEnabled = lazyConnectionEnabled
if enableSSHRoot != nil {
i.EnableSSHRoot = *enableSSHRoot
}
@@ -174,7 +172,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
}
files, err := checkFileAndProcess(processCheckPaths)
files, err := checkFileAndProcess(ctx, processCheckPaths)
if err != nil {
return nil, err
}
@@ -187,3 +185,43 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs .
log.Debugf("all system information gathered successfully")
return info, nil
}
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
// does not return within timeout the caller gets (nil, false) and should proceed with
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
//
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
// returns, so it does not leak beyond the duration of that call.
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, bool) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
infoCh := make(chan *Info, 1)
go func() {
info, err := GetInfoWithChecks(ctx, checks, excludeIPs...)
if err != nil {
if ctx.Err() != nil {
return
}
log.Warnf("failed to get system info with checks: %v", err)
info = GetInfo(ctx)
info.removeAddresses(excludeIPs...)
}
infoCh <- info
}()
select {
case info := <-infoCh:
return info, true
case <-ctx.Done():
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
log.Warnf("gathering system info with checks timed out after %s", timeout)
} else {
// Parent context canceled (e.g. shutdown), not a timeout.
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
}
return nil, false
}
}

View File

@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
if err != nil {
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
swVersion = []byte(release)

View File

@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
}
}
func checkFileAndProcess(_ []string) ([]File, error) {
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
return []File{}, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"
@@ -35,6 +36,20 @@ func Test_CustomHostname(t *testing.T) {
assert.Equal(t, want, got.Hostname)
}
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
assert.True(t, ok, "expected gathering to complete within the timeout")
assert.NotNil(t, info)
}
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
// A 1ns budget expires before the (real) system-info gathering can finish, so the
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
assert.False(t, ok, "expected timeout to be reported")
assert.Nil(t, info)
}
func Test_NetAddresses(t *testing.T) {
addr, err := networkAddresses()
if err != nil {

View File

@@ -3,24 +3,30 @@
package system
import (
"context"
"os"
"slices"
"github.com/shirou/gopsutil/v3/process"
)
// getRunningProcesses returns a list of running process paths.
func getRunningProcesses() ([]string, error) {
processIDs, err := process.Pids()
// getRunningProcesses returns a list of running process paths. The context bounds the work:
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
// can, so a stuck enumeration cannot run unbounded.
func getRunningProcesses(ctx context.Context) ([]string, error) {
processIDs, err := process.PidsWithContext(ctx)
if err != nil {
return nil, err
}
processMap := make(map[string]bool)
for _, pID := range processIDs {
if err := ctx.Err(); err != nil {
return nil, err
}
p := &process.Process{Pid: pID}
path, _ := p.Exe()
path, _ := p.ExeWithContext(ctx)
if path != "" {
processMap[path] = false
}
@@ -35,18 +41,21 @@ func getRunningProcesses() ([]string, error) {
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
files := make([]File, len(paths))
if len(paths) == 0 {
return files, nil
}
runningProcesses, err := getRunningProcesses()
runningProcesses, err := getRunningProcesses(ctx)
if err != nil {
return nil, err
}
for i, path := range paths {
if err := ctx.Err(); err != nil {
return nil, err
}
file := File{Path: path}
_, err := os.Stat(path)

View File

@@ -1,6 +1,7 @@
package system
import (
"context"
"testing"
"github.com/shirou/gopsutil/v3/process"
@@ -9,7 +10,7 @@ import (
func Benchmark_getRunningProcesses(b *testing.B) {
b.Run("getRunningProcesses new", func(b *testing.B) {
for i := 0; i < b.N; i++ {
ps, err := getRunningProcesses()
ps, err := getRunningProcesses(context.Background())
if err != nil {
b.Fatalf("unexpected error: %v", err)
}
@@ -29,12 +30,38 @@ func Benchmark_getRunningProcesses(b *testing.B) {
}
}
})
s, _ := getRunningProcesses()
s, _ := getRunningProcesses(context.Background())
b.Logf("getRunningProcesses returned %d processes", len(s))
s, _ = getRunningProcessesOld()
b.Logf("getRunningProcessesOld returned %d processes", len(s))
}
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
// With a canceled context and non-empty paths the gathering must bail with an error
// instead of running the (potentially blocking) process scan / stat loop.
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
t.Fatal("expected error on canceled context, got nil")
}
}
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
// No check paths means no work to do: it must return immediately with no error,
// even on a canceled context (nothing to scan or stat).
ctx, cancel := context.WithCancel(context.Background())
cancel()
files, err := checkFileAndProcess(ctx, nil)
if err != nil {
t.Fatalf("unexpected error for empty paths: %v", err)
}
if len(files) != 0 {
t.Fatalf("expected no files, got %d", len(files))
}
}
func getRunningProcessesOld() ([]string, error) {
processes, err := process.Processes()
if err != nil {

View File

@@ -266,7 +266,6 @@ type serviceClient struct {
mAllowSSH *systray.MenuItem
mAutoConnect *systray.MenuItem
mEnableRosenpass *systray.MenuItem
mLazyConnEnabled *systray.MenuItem
mBlockInbound *systray.MenuItem
mNotifications *systray.MenuItem
mAdvancedSettings *systray.MenuItem
@@ -336,11 +335,11 @@ type serviceClient struct {
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
// AND s.connected — both must be true for the menus to be active.
// Zero value (false) matches the Disable() call at AddMenuItem time.
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
eventManager *event.Manager
@@ -1094,7 +1093,6 @@ func (s *serviceClient) onTrayReady() {
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false)
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
s.mSettings.AddSeparator()
@@ -1578,7 +1576,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
config.RosenpassEnabled = cfg.RosenpassEnabled
config.RosenpassPermissive = cfg.RosenpassPermissive
config.DisableNotifications = &cfg.DisableNotifications
config.LazyConnectionEnabled = cfg.LazyConnectionEnabled
config.BlockInbound = cfg.BlockInbound
config.NetworkMonitor = &cfg.NetworkMonitor
config.DisableDNS = cfg.DisableDns
@@ -1682,12 +1679,6 @@ func (s *serviceClient) loadSettings() {
s.mEnableRosenpass.Uncheck()
}
if cfg.LazyConnectionEnabled {
s.mLazyConnEnabled.Check()
} else {
s.mLazyConnEnabled.Uncheck()
}
if cfg.BlockInbound {
s.mBlockInbound.Check()
} else {
@@ -1833,7 +1824,6 @@ func (s *serviceClient) updateConfig() error {
disableAutoStart := !s.mAutoConnect.Checked()
sshAllowed := s.mAllowSSH.Checked()
rosenpassEnabled := s.mEnableRosenpass.Checked()
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
blockInbound := s.mBlockInbound.Checked()
notificationsDisabled := !s.mNotifications.Checked()
@@ -1856,14 +1846,13 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableNotifications: &notificationsDisabled,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,
RosenpassEnabled: &rosenpassEnabled,
BlockInbound: &blockInbound,
DisableNotifications: &notificationsDisabled,
}
if _, err := conn.SetConfig(s.ctx, &req); err != nil {

View File

@@ -4,7 +4,6 @@ const (
allowSSHMenuDescr = "Allow SSH connections"
autoConnectMenuDescr = "Connect automatically when the service starts"
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
blockInboundMenuDescr = "Block inbound connections to the local machine and routed networks"
notificationsMenuDescr = "Enable notifications"
advancedSettingsMenuDescr = "Advanced settings of the application"

View File

@@ -43,8 +43,6 @@ func (h *eventHandler) listen(ctx context.Context) {
h.handleAutoConnectClick()
case <-h.client.mEnableRosenpass.ClickedCh:
h.handleRosenpassClick()
case <-h.client.mLazyConnEnabled.ClickedCh:
h.handleLazyConnectionClick()
case <-h.client.mBlockInbound.ClickedCh:
h.handleBlockInboundClick()
case <-h.client.mAdvancedSettings.ClickedCh:
@@ -152,15 +150,6 @@ func (h *eventHandler) handleRosenpassClick() {
}
}
func (h *eventHandler) handleLazyConnectionClick() {
h.toggleCheckbox(h.client.mLazyConnEnabled)
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
}
}
func (h *eventHandler) handleBlockInboundClick() {
h.toggleCheckbox(h.client.mBlockInbound)
if err := h.updateConfigWithErr(); err != nil {

View File

@@ -9,6 +9,8 @@ set -o pipefail
SED_STRIP_PADDING='s/=//g'
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
@@ -139,6 +141,43 @@ read_yes_no() {
esac
}
# Gate the install on explicit acceptance of the NetBird On-Premise EULA.
require_eula_acceptance() {
cat > /dev/stderr <<EOF
──────────────────────────────────────────────────────────────────────
NetBird On-Premise End User License Agreement
──────────────────────────────────────────────────────────────────────
NetBird's on-premise software is commercial software, licensed and not
sold. Your installation, deployment and use are governed by the NetBird
On-Premise End User License Agreement (the "EULA"). Please read the EULA
in full before continuing:
${NETBIRD_EULA_URL}
By typing "accept" and continuing the installation, you confirm that you
have read and agree to the EULA, that you are authorized to accept it on
behalf of your organization (the "Customer"), and that the Software is
used for business purposes only.
──────────────────────────────────────────────────────────────────────
EOF
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
return 0
fi
local ans=""
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
read -r ans < /dev/tty
if [[ "$ans" != "accept" ]]; then
echo "" > /dev/stderr
echo "EULA not accepted. Aborting installation." > /dev/stderr
exit 1
fi
echo "" > /dev/stderr
}
wait_postgres() {
set +e
echo -n "Waiting for postgres to become ready"
@@ -174,6 +213,9 @@ init_environment() {
exit 1
fi
require_eula_acceptance
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
echo "NetBird Enterprise bootstrap"
echo ""
echo "Traffic flow:"
@@ -260,6 +302,11 @@ render_env() {
# Generated by getting-started-enterprise.sh
# Holds all configuration and secrets for the stack. Mode 600.
# NetBird On-Premise EULA acceptance
NETBIRD_EULA_ACCEPTED=yes
NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}
NETBIRD_EULA_URL=${NETBIRD_EULA_URL}
# Features (set by the script; don't edit without re-running)
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}

View File

@@ -25,6 +25,8 @@ set -o pipefail
OVERRIDE_FILE="docker-compose.override.yml"
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
NETBIRD_EULA_URL="https://netbird.io/self-hosted-EULA"
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
@@ -115,6 +117,43 @@ read_yes_no() {
esac
}
# Gate the migration on explicit acceptance of the NetBird On-Premise EULA.
require_eula_acceptance() {
cat > /dev/stderr <<EOF
──────────────────────────────────────────────────────────────────────
NetBird On-Premise End User License Agreement
──────────────────────────────────────────────────────────────────────
NetBird's on-premise software is commercial software, licensed and not
sold. Your installation, deployment and use are governed by the NetBird
On-Premise End User License Agreement (the "EULA"). Please read the EULA
in full before continuing:
${NETBIRD_EULA_URL}
By typing "accept" and continuing the installation, you confirm that you
have read and agree to the EULA, that you are authorized to accept it on
behalf of your organization (the "Customer"), and that the Software is
used for business purposes only.
──────────────────────────────────────────────────────────────────────
EOF
if [[ "${NB_ACCEPT_EULA:-}" == "yes" ]]; then
echo "EULA accepted via NB_ACCEPT_EULA=yes." > /dev/stderr
return 0
fi
local ans=""
echo -n 'Type "accept" to agree, or anything else to abort: ' > /dev/stderr
read -r ans < /dev/tty
if [[ "$ans" != "accept" ]]; then
echo "" > /dev/stderr
echo "EULA not accepted. Aborting migration." > /dev/stderr
exit 1
fi
echo "" > /dev/stderr
}
# ---------------------------------------------------------------------------
# Detection — read the operator's existing compose to find service names and
# paths we need to override. Bail loudly if shape isn't recognised.
@@ -436,6 +475,9 @@ init_migration() {
echo " Network: $COMPOSE_NETWORK"
echo ""
require_eula_acceptance
NETBIRD_EULA_ACCEPTED_AT=$(date -u +%Y-%m-%dT%H:%M:%SZ)
local proceed
proceed=$(read_yes_no "Proceed with migration?" "y")
if [[ "$proceed" != "yes" ]]; then
@@ -529,6 +571,10 @@ apply_changes() {
{
echo ""
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
echo "# NetBird On-Premise EULA accepted at install time"
echo "NETBIRD_EULA_ACCEPTED=yes"
echo "NETBIRD_EULA_ACCEPTED_AT=${NETBIRD_EULA_ACCEPTED_AT}"
echo "NETBIRD_EULA_URL=${NETBIRD_EULA_URL}"
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"

View File

@@ -55,6 +55,14 @@ type GrpcClient struct {
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
serverURL string
// syncStreamErr holds the last Sync stream error, or nil while the stream
// is established and healthy. GetServerKey succeeds even when the peer
// cannot sync (e.g. the server returns "settings not found"), so the
// health probe must consult this to avoid reporting a healthy management
// connection while the Sync stream keeps failing.
syncStreamMu sync.RWMutex
syncStreamErr error
}
type ExposeRequest struct {
@@ -364,6 +372,8 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
@@ -372,11 +382,13 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
log.Infof("connected to the Management Service stream")
c.notifyConnected()
c.setSyncStreamConnected()
// blocking until error
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if ctx.Err() != nil {
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
@@ -530,6 +542,13 @@ func (c *GrpcClient) IsHealthy() bool {
log.Warnf("health check returned: %s", err)
return false
}
if syncErr := c.syncStreamError(); syncErr != nil {
c.notifyDisconnected(syncErr)
log.Warnf("management transport is up but the Sync stream is unhealthy: %s", syncErr)
return false
}
c.notifyConnected()
return true
}
@@ -771,6 +790,24 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return err
}
func (c *GrpcClient) setSyncStreamConnected() {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = nil
}
func (c *GrpcClient) setSyncStreamDisconnected(err error) {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = err
}
func (c *GrpcClient) syncStreamError() error {
c.syncStreamMu.RLock()
defer c.syncStreamMu.RUnlock()
return c.syncStreamErr
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()
@@ -993,8 +1030,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
BlockLANAccess: info.BlockLANAccess,
BlockInbound: info.BlockInbound,
DisableIPv6: info.DisableIPv6,
LazyConnectionEnabled: info.LazyConnectionEnabled,
},
Capabilities: peerCapabilities(*info),

View File

@@ -85,6 +85,7 @@ type GrpcClient struct {
// receive backpressure as a dead stream: reconnecting cannot help, since the
// new stream feeds the same worker, and only triggers a reconnect storm.
receiveHandoffBlocked atomic.Bool
watchdogWg sync.WaitGroup
}
// NewClient creates a new Signal client
@@ -200,10 +201,18 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
// Guard the receive direction: the transport can stay healthy while the
// server stops delivering messages. The watchdog reconnects via cancelStream.
c.markReceived()
go c.watchReceiveStream(streamCtx, cancelStream)
c.watchdogWg.Add(1)
go func() {
defer c.watchdogWg.Done()
c.watchReceiveStream(streamCtx, cancelStream)
}()
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream)
cancelStream()
c.watchdogWg.Wait()
if err != nil {
// Check the parent context, not streamCtx: a watchdog-triggered
// cancelStream must reconnect, only a parent cancel is shutdown.
@@ -400,7 +409,12 @@ func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage
// Send sends a message to the remote Peer through the Signal Exchange.
func (c *GrpcClient) Send(msg *proto.Message) error {
return c.send(c.ctx, msg)
}
// send delivers a message deriving per-attempt timeouts from parentCtx, so a
// caller can abort an in-flight send by cancelling that context.
func (c *GrpcClient) send(parentCtx context.Context, msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
@@ -416,7 +430,7 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
if attempt > 1 {
attemptTimeout = time.Duration(attempt) * 5 * time.Second
}
ctx, cancel := context.WithTimeout(c.ctx, attemptTimeout)
ctx, cancel := context.WithTimeout(parentCtx, attemptTimeout)
_, err = c.realClient.Send(ctx, encryptedMessage)
@@ -486,7 +500,7 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
}
if probeSentAt.IsZero() {
if err := c.sendReceiveProbe(); err != nil {
if err := c.sendReceiveProbe(ctx); err != nil {
log.Debugf("failed to send signal receive probe: %v", err)
}
probeSentAt = time.Now()
@@ -495,11 +509,13 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
}
}
// sendReceiveProbe sends a self-addressed heartbeat. The Signal server routes it
// back to this client, exercising the exact receive path the watchdog guards.
func (c *GrpcClient) sendReceiveProbe() error {
// sendReceiveProbe sends a self-addressed heartbeat bound to ctx, so cancelStream
// aborts an in-flight probe instead of leaving the watchdog blocked on send timeouts.
// The Signal server routes it back to this client, exercising the exact receive
// path the watchdog guards.
func (c *GrpcClient) sendReceiveProbe(ctx context.Context) error {
self := c.key.PublicKey().String()
return c.Send(&proto.Message{
return c.send(ctx, &proto.Message{
Key: self,
RemoteKey: self,
Body: &proto.Body{Type: proto.Body_HEARTBEAT},
@@ -541,6 +557,9 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) er
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
log.Errorf("failed to add message to decryption worker: %v", err)
}
// Refresh liveness before clearing the flag so the window between here and
// the next Recv does not read a stale timestamp as a dead stream.
c.markReceived()
c.receiveHandoffBlocked.Store(false)
}
}

View File

@@ -2,6 +2,7 @@ package client
import (
"context"
"io"
"net"
"testing"
"time"
@@ -74,7 +75,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
t.Fatal("signal stream did not connect within timeout")
}
require.NoError(t, client.sendReceiveProbe())
require.NoError(t, client.sendReceiveProbe(ctx))
select {
case <-received:
@@ -106,3 +107,72 @@ func TestReceiveAliveTreatsHandoffBlockAsLiveness(t *testing.T) {
c.markReceived()
require.True(t, c.receiveAlive(), "a freshly received frame must keep the stream alive")
}
// fakeRecvStream feeds the receive loop frames from a channel and reports EOF
// once the channel is closed. Only Recv is exercised by the loop.
type fakeRecvStream struct {
sigProto.SignalExchange_ConnectStreamClient
frames chan *sigProto.EncryptedMessage
}
func (s *fakeRecvStream) Recv() (*sigProto.EncryptedMessage, error) {
msg, ok := <-s.frames
if !ok {
return nil, io.EOF
}
return msg, nil
}
// TestReceiveLoopRefreshesLivenessAfterBlockedHandoff drives the real receive
// loop into a handoff that blocks past the inactivity threshold, then checks the
// window after the handoff drains but before the next Recv. The loop must have
// refreshed the timestamp on unblocking, otherwise that window reads the stale
// pre-handoff timestamp as a dead stream and the watchdog tears down a healthy
// connection.
func TestReceiveLoopRefreshesLivenessAfterBlockedHandoff(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
c := &GrpcClient{ctx: ctx}
handling := make(chan struct{}, 8)
gate := make(chan struct{})
decrypt := func(*sigProto.EncryptedMessage) (*sigProto.Message, error) { return &sigProto.Message{}, nil }
handler := func(*sigProto.Message) error {
handling <- struct{}{}
<-gate
return nil
}
c.decryptionWorker = NewWorker(decrypt, handler)
workerCtx, workerCancel := context.WithCancel(context.Background())
go c.decryptionWorker.Work(workerCtx)
t.Cleanup(workerCancel)
frames := make(chan *sigProto.EncryptedMessage)
t.Cleanup(func() { close(frames) })
go func() { _ = c.receive(&fakeRecvStream{frames: frames}) }()
// First frame: the worker drains it and parks in the blocking handler.
frames <- &sigProto.EncryptedMessage{}
<-handling
// Second frame fills the worker's single-slot pool.
frames <- &sigProto.EncryptedMessage{}
// Third frame: the pool is full, so the loop parks on the handoff.
frames <- &sigProto.EncryptedMessage{}
require.Eventually(t, c.receiveHandoffBlocked.Load, time.Second, time.Millisecond,
"receive loop should park on the worker handoff")
// Simulate the handoff having blocked past the inactivity threshold.
c.lastReceived.Store(time.Now().Add(-2 * receiveInactivityThreshold).UnixNano())
require.True(t, c.receiveAlive(), "a loop parked on the handoff must stay alive")
// Drain the worker so the handoff returns and the loop resumes reading.
close(gate)
// Once the handoff clears, the loop is parked on the next Recv with no frame
// pending. The stream must still read as alive in that window.
require.Eventually(t, func() bool { return !c.receiveHandoffBlocked.Load() }, time.Second, time.Millisecond,
"handoff should drain once the worker is released")
require.True(t, c.receiveAlive(),
"the loop must refresh liveness when the handoff drains, before the next Recv")
}