Compare commits

..

9 Commits

Author SHA1 Message Date
dependabot[bot]
4cc54695a3 Bump the actions group with 5 updates
Bumps the actions group with 5 updates:

| Package | From | To |
| --- | --- | --- |
| [actions/cache](https://github.com/actions/cache) | `6.0.0` | `6.1.0` |
| [vmactions/freebsd-vm](https://github.com/vmactions/freebsd-vm) | `1.4.8` | `1.5.0` |
| [actions/cache/restore](https://github.com/actions/cache) | `6.0.0` | `6.1.0` |
| [golangci/golangci-lint-action](https://github.com/golangci/golangci-lint-action) | `9.2.1` | `9.3.0` |
| [goreleaser/goreleaser-action](https://github.com/goreleaser/goreleaser-action) | `7.2.2` | `7.2.3` |


Updates `actions/cache` from 6.0.0 to 6.1.0
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](2c8a9bd745...55cc834586)

Updates `vmactions/freebsd-vm` from 1.4.8 to 1.5.0
- [Release notes](https://github.com/vmactions/freebsd-vm/releases)
- [Commits](b84ab5559b...5a72679103)

Updates `actions/cache/restore` from 6.0.0 to 6.1.0
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](2c8a9bd745...55cc834586)

Updates `golangci/golangci-lint-action` from 9.2.1 to 9.3.0
- [Release notes](https://github.com/golangci/golangci-lint-action/releases)
- [Commits](82606bf257...ba0d7d2ec0)

Updates `goreleaser/goreleaser-action` from 7.2.2 to 7.2.3
- [Release notes](https://github.com/goreleaser/goreleaser-action/releases)
- [Commits](5daf1e915a...f06c13b6b1)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-version: 6.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: vmactions/freebsd-vm
  dependency-version: 1.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: actions/cache/restore
  dependency-version: 6.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: golangci/golangci-lint-action
  dependency-version: 9.3.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: goreleaser/goreleaser-action
  dependency-version: 7.2.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-30 02:46:15 +00:00
Zoltan Papp
04c3d19032 [client] Skip firewall ruleset rebuild when config is unchanged (#6508)
* [client] Skip firewall ruleset rebuild when config is unchanged

ApplyFiltering rebuilt every peer and route ACL and flushed the firewall
on every sync, with no guard for an unchanged configuration. Management
re-sends the same network map far more often than it actually changes
(account-wide updates, peer meta churn), so on busy accounts this is the
dominant client-side cost of redundant syncs — especially with a large
route set and a userspace firewall.

Hash the inputs ApplyFiltering consumes (peer rules, route rules, the
empty flag and the dns-route feature flag) and skip the rebuild + flush
when the hash matches the last successfully applied update. Mirrors the
guard the DNS server already uses (previousConfigHash). The hash is only
recorded after apply and flush both succeed, so a failed update is not
skipped on the next (possibly identical) sync and gets a chance to
reconcile the firewall state.

* [client] Include config hash in ACL skip debug log

* [client] Include RoutesFirewallRulesIsEmpty in firewall config hash

* [client] Add benchmarks for firewall config hash computation
2026-06-29 19:51:50 +02:00
Zoltan Papp
3f1fb3b52d [ingest] raise duration validation limit to 24 hours (#6598)
Peer connection timing fields (signaling_to_connection_seconds) can
legitimately exceed 5 minutes during long reconnections; the previous
300 s cap caused valid data points to be rejected.
2026-06-29 19:51:25 +02:00
Viktor Liu
b434cda062 [client] Refresh signal receive liveness when worker handoff drains (#6594) 2026-06-29 12:16:47 +02:00
Zoltan Papp
0b594c639a [client] report management unhealthy while Sync stream is failing (#6575)
* fix(mgm): report management unhealthy while Sync stream is failing

The health probe (IsHealthy) only checked the gRPC transport and a
GetServerKey call. GetServerKey succeeds even when the peer cannot sync
(e.g. the server returns "settings not found"), so the probe kept marking
management Connected while the Sync stream failed in a tight retry loop —
pinning the status to "Connected" forever despite no sync ever succeeding.

Track the last Sync stream error and have IsHealthy consult it, so a
healthy transport is no longer enough to report the connection healthy.

* fix(mgm): record disconnected state when sync stream setup fails

The connectToSyncStream failure path in handleSyncStream returned early
without updating syncStreamErr, so the client could still report healthy
even when stream setup failed. Mirror the receiveUpdatesEvents error path
by calling notifyDisconnected and setSyncStreamDisconnected.
2026-06-29 11:28:58 +02:00
Zoltan Papp
deff8af59f [client] Wait for signal receive watchdog to stop before reconnect (#6574)
* [client] Wait for signal receive watchdog to stop before reconnect

The per-stream watchReceiveStream goroutine was started fire-and-forget
and never joined. On reconnect a lingering watchdog could still flip
shared client state (receiveStalled, the disconnect notifier) on the
freshly established stream, since cancelStream only cancels its own
stream context.

Track the watchdog with a WaitGroup and wait for it to exit (after
cancelling its stream) before the operation returns, so each reconnect
starts with no stale watchdog.

* [client] Bind signal receive probe to the stream context

The watchdog probe reused the generic Send, which derives its per-attempt
timeouts from the long-lived client context, so cancelStream could not
interrupt an in-flight probe. After joining the watchdog on reconnect,
watchdogWg.Wait() could then block for the full send-attempt chain.

Split Send into a context-aware send and pass the stream context down
through sendReceiveProbe, so cancelStream aborts any in-flight probe and
the watchdog exits promptly.
2026-06-29 11:24:25 +02:00
Riccardo Manfrin
5711f0e38c [client] add per-phase timing metrics for sync processing (#6533)
* Adds metrics sync phases time split to monitor costs

* Address review fixes

* Increment README.md with description on usage with debug bundles
2026-06-29 11:02:02 +02:00
Maycon Santos
1409a1325a [misc] Update careers page link (#6538) 2026-06-29 09:19:01 +02:00
Viktor Liu
4400372f37 [client] Forward non-address DNS record types through route forwarders (#6455) 2026-06-28 18:50:17 +02:00
32 changed files with 1614 additions and 747 deletions

View File

@@ -27,7 +27,7 @@ jobs:
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}

View File

@@ -28,7 +28,7 @@ jobs:
id: test
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
uses: vmactions/freebsd-vm@5a72679103d223925653750faa878a143340fbd0 # v1.5.0
with:
usesh: true
copyback: false

View File

@@ -41,7 +41,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
id: cache
with:
path: |
@@ -135,7 +135,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -192,7 +192,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
id: cache-restore
with:
path: |
@@ -266,7 +266,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -325,7 +325,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -383,7 +383,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -440,7 +440,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -545,7 +545,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -640,7 +640,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}
@@ -710,7 +710,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}

View File

@@ -35,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
${{ env.cache }}

View File

@@ -56,7 +56,7 @@ jobs:
if: matrix.os == 'ubuntu-latest'
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@ba0d7d2ec06a0ea1cb5fa41b2e4a3ab91d21278a #v9.3.0
with:
version: latest
skip-cache: true

View File

@@ -34,7 +34,7 @@ jobs:
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620

View File

@@ -64,7 +64,7 @@ jobs:
if: steps.check_diff.outputs.diff_exists == 'true'
env:
GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
uses: vmactions/freebsd-vm@5a72679103d223925653750faa878a143340fbd0 # v1.5.0
with:
usesh: true
copyback: false
@@ -171,7 +171,7 @@ jobs:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache/restore@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
~/go/pkg/mod
@@ -221,7 +221,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@f06c13b6b1a9625abc9e6e439d9c05a8f2190e94 # v7.2.3
with:
version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }}
@@ -379,7 +379,7 @@ jobs:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
~/go/pkg/mod
@@ -420,7 +420,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@f06c13b6b1a9625abc9e6e439d9c05a8f2190e94 # v7.2.3
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -474,7 +474,7 @@ jobs:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: |
~/go/pkg/mod
@@ -488,7 +488,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Run GoReleaser
id: goreleaser
uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
uses: goreleaser/goreleaser-action@f06c13b6b1a9625abc9e6e439d9c05a8f2190e94 # v7.2.3
with:
version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}

View File

@@ -78,7 +78,7 @@ jobs:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
uses: actions/cache@55cc8345863c7cc4c66a329aec7e433d2d1c52a9 # v6.1.0
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}

View File

@@ -29,7 +29,7 @@ jobs:
- name: Install dependencies
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
- name: Install golangci-lint
uses: golangci/golangci-lint-action@82606bf257cbaff209d206a39f5134f0cfbfd2ee #v9.2.1
uses: golangci/golangci-lint-action@ba0d7d2ec06a0ea1cb5fa41b2e4a3ab91d21278a #v9.3.0
with:
version: latest
install-mode: binary

View File

@@ -33,7 +33,7 @@
<br/>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
</strong>
</p>

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -30,11 +31,13 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
previousConfigHash uint64
hasAppliedConfig bool
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
@@ -57,6 +60,23 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
return
}
// Skip the full rebuild + flush when the inputs that drive the firewall
// state are byte-for-byte identical to the last successfully applied
// update. Management re-sends the same network map far more often than it
// actually changes (account-wide updates, peer meta churn), and rebuilding
// every peer/route ACL and flushing the firewall on every such sync is the
// dominant client-side cost when nothing changed. Mirrors the same guard the
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
// consumes participate in the hash, so an unrelated map change cannot mask a
// real ACL change.
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
if err != nil {
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
return
}
start := time.Now()
defer func() {
total := 0
@@ -70,13 +90,49 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
if routeErr != nil {
log.Errorf("Failed to apply route ACLs: %v", routeErr)
}
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
flushErr := d.firewall.Flush()
if flushErr != nil {
log.Error("failed to flush firewall rules: ", flushErr)
}
// Only remember the hash once the firewall actually reflects this config.
// If applying or flushing failed, leave the previous hash untouched so the
// next (possibly identical) update is not skipped and gets a chance to
// reconcile the firewall state.
if err == nil && routeErr == nil && flushErr == nil {
d.previousConfigHash = hash
d.hasAppliedConfig = true
} else {
d.hasAppliedConfig = false
}
}
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
// firewall state, so an identical hash means an identical resulting ruleset.
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
return hashstructure.Hash(struct {
PeerRules []*mgmProto.FirewallRule
PeerRulesIsEmpty bool
RouteRules []*mgmProto.RouteFirewallRule
RouteRulesIsEmpty bool
DNSRouteFeatureFlag bool
}{
PeerRules: networkMap.GetFirewallRules(),
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
RouteRules: networkMap.GetRoutesFirewallRules(),
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
}, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {

View File

@@ -1,6 +1,7 @@
package acl
import (
"fmt"
"net/netip"
"testing"
@@ -485,3 +486,149 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
// re-applied is recognized as a no-op (hash unchanged), while a real change to
// any firewall-relevant input forces a re-apply (hash changes). This is the
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
firstHash := acl.previousConfigHash
require.NotZero(t, firstHash)
// Re-applying the identical map must not change the recorded hash: the
// expensive rebuild path was skipped.
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, firstHash, acl.previousConfigHash,
"identical re-apply must be a no-op (hash unchanged)")
// A real change must produce a different hash and re-apply.
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
acl.ApplyFiltering(networkMap, false)
assert.NotEqual(t, firstHash, acl.previousConfigHash,
"changing a rule's action must force a re-apply (hash changed)")
// The dnsRouteFeatureFlag also participates in the hash.
changedHash := acl.previousConfigHash
acl.ApplyFiltering(networkMap, true)
assert.NotEqual(t, changedHash, acl.previousConfigHash,
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
}
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
nm := &mgmProto.NetworkMap{
FirewallRulesIsEmpty: peerRules == 0,
RoutesFirewallRulesIsEmpty: routeRules == 0,
}
for i := range peerRules {
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: fmt.Sprintf("%d", 1024+i%64511),
})
}
for i := range routeRules {
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
})
}
return nm
}
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(10, 5)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(100, 50)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(1000, 200)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
// inputs and order-independent for the rule slices (management does not
// guarantee rule order).
func TestFirewallConfigHashDeterministic(t *testing.T) {
d := &DefaultManager{}
nm1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
},
}
// Same rules, reversed order.
nm2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
nm1.FirewallRules[1],
nm1.FirewallRules[0],
},
}
h1, err := d.firewallConfigHash(nm1, false)
require.NoError(t, err)
h2, err := d.firewallConfigHash(nm2, false)
require.NoError(t, err)
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"net"
"net/netip"
"slices"
"strings"
"github.com/miekg/dns"
@@ -167,7 +168,10 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
// Non-address types reach LookupIP only unexpectedly; without an
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -184,6 +188,230 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess
}
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net"
"net/netip"
"strings"
"testing"
"github.com/miekg/dns"
@@ -121,6 +122,164 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{

View File

@@ -37,6 +37,12 @@ const (
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
type firewaller interface {
@@ -210,12 +216,6 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
@@ -227,9 +227,46 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
reqHasEdns := query.IsEdns0() != nil
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if reqHasEdns {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
reqHasEdns bool,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
return
}
@@ -240,6 +277,25 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
f.writeResponse(logger, w, resp, qname, startTime)
}
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)

View File

@@ -133,6 +133,41 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1)
}
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
@@ -545,12 +580,15 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
expectEDE bool
description string
}{
{
@@ -562,28 +600,13 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
name: "unsupported query type returns NODATA",
queryType: dns.TypeCAA,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
expectedCode: dns.RcodeSuccess,
expectEDE: true,
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
},
}
@@ -599,6 +622,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response
var writtenResp *dns.Msg
@@ -614,10 +638,213 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string

View File

@@ -210,12 +210,6 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
// forwardingRules holds the ingress forward rules applied for the current target.
// Wholesale sections (incl. forward rules) run only on the first pass of a target;
// it is stashed here so the final, peer-converged pass can build the lazy-connection
// exclude list without recomputing them on every bounded peer pass.
forwardingRules []firewallManager.ForwardRule
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
@@ -768,15 +762,7 @@ func (e *Engine) blockLanAccess() {
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
// maxPeersPerSyncPass is the default per-pass cap on how many peers each of
// removePeers/modifyPeers/addNewPeers applies, so syncMsgMux is held only for a
// batch at a time and other subsystems can interleave between passes. It is
// passed in (not read globally) so tests can exercise the multi-pass path.
const maxPeersPerSyncPass = 300
// modifyPeers re-applies up to maxBatch changed peers per call. It returns true
// when more changed peers remained than the cap, so the caller re-runs.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// first, check if peers have been modified
var modified []*mgmProto.RemotePeerConfig
@@ -806,32 +792,26 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch
}
}
more := false
if len(modified) > maxBatch {
modified = modified[:maxBatch]
more = true
}
// second, close all modified connections and remove them from the state map
for _, p := range modified {
if err := e.removePeer(p.GetWgPubKey()); err != nil {
return false, err
err := e.removePeer(p.GetWgPubKey())
if err != nil {
return err
}
}
// third, add the peer connections again
for _, p := range modified {
if err := e.addNewPeer(p); err != nil {
return false, err
err := e.addNewPeer(p)
if err != nil {
return err
}
}
return more, nil
return nil
}
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
// removePeers removes up to maxBatch peers per call. It returns true when more
// peers remained to remove than the cap, so the caller re-runs.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey())
@@ -839,19 +819,14 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
more := false
if len(toRemove) > maxBatch {
toRemove = toRemove[:maxBatch]
more = true
}
for _, p := range toRemove {
if err := e.removePeer(p); err != nil {
return false, err
err := e.removePeer(p)
if err != nil {
return err
}
log.Infof("removed peer %s", p)
}
return more, nil
return nil
}
func (e *Engine) removeAllPeers() error {
@@ -920,25 +895,40 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
// applySyncPass applies one bounded pass of the sync update under syncMsgMux and
// returns true if more peers remained than the per-pass cap. It is driven by the
// mapStateManager, which re-invokes it (releasing the lock between passes) until
// the update is fully applied.
func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (bool, error) {
// phase times a sync sub-phase: it returns a function that records the elapsed
// duration when called. Starting the timer at the call site keeps inter-phase
// glue code out of the measurement.
func (e *Engine) phase(name string) func() {
start := time.Now()
return func() {
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
}
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
duration := time.Since(started)
log.Infof("sync finished in %s", duration)
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
}()
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
// Check context INSIDE lock to ensure atomicity with shutdown
if e.ctx.Err() != nil {
return false, e.ctx.Err()
return e.ctx.Err()
}
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
return false, err
done := e.phase("netbird_config")
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
done()
if err != nil {
return err
}
// Posture checks are bound to the network map presence:
@@ -948,22 +938,28 @@ func (e *Engine) applySyncPass(update *mgmProto.SyncResponse, firstPass bool) (b
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return false, nil
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
return false, err
done = e.phase("checks")
err = e.updateChecksIfNew(update.Checks)
done()
if err != nil {
return err
}
done = e.phase("persist")
e.persistSyncResponse(update)
done()
// only apply new changes and ignore old ones
more, err := e.updateNetworkMap(nm, maxPeersPerSyncPass, firstPass)
if err != nil {
return false, err
if err := e.updateNetworkMap(nm); err != nil {
return err
}
e.statusRecorder.PublishEvent(cProto.SystemEvent_INFO, cProto.SystemEvent_SYSTEM, "Network map updated", "", nil)
return more, nil
return nil
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
@@ -1009,13 +1005,6 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
// (not syncMsgMux) is held for the whole Set so the store cannot be cleared (disabled /
// engine close) mid-call and have this write resurrect a file that was just removed.
func (e *Engine) persistSyncResponse(update *mgmProto.SyncResponse) {
// Only persist updates that carry a network map. Config-only updates (e.g. relay
// token rotation, STUN/TURN) have a nil NetworkMap; persisting them would overwrite
// the last full map on disk and break restore-on-restart.
if update.GetNetworkMap() == nil {
return
}
e.syncRespMux.RLock()
defer e.syncRespMux.RUnlock()
@@ -1307,19 +1296,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.DisableSSHAuth,
)
// The map-state manager converges the latest update in the background in
// bounded passes; the stream callback only hands it the newest target.
manager := newMapStateManager(e.applySyncPass, e.persistSyncResponse, func(d time.Duration) {
log.Infof("sync finished in %s", d)
e.clientMetrics.RecordSyncDuration(e.ctx, d)
})
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
manager.run(e.ctx)
}()
err = e.mgmClient.Sync(e.ctx, info, manager.SetTarget)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
if err != nil {
// happens if management is unavailable for a long time.
// We want to cancel the operation of the whole client
@@ -1370,104 +1347,21 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
return nil
}
// updateNetworkMap applies the wholesale parts (config, routes, ACL, DNS) in full
// and up to maxBatch peers per phase. It returns true when more peers remained
// than the cap, so the caller re-runs until convergence.
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap, maxBatch int, firstPass bool) (bool, error) {
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
if err != nil {
return false, err
return err
}
}
serial := networkMap.GetSerial()
if e.networkSerial > serial {
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
return false, nil
return nil
}
// Wholesale sections (firewall/ACL, DNS, routes, forward rules) are applied
// up-front and only once per target: they are cheap, local, idempotent and must
// be in place before peers come up (fail-closed). On the bounded re-runs that only
// drain the remaining peer batches they are skipped — the applied forward rules are
// reused from e.forwardingRules for the lazy-exclude finalize.
if firstPass {
e.applyWholesale(networkMap, serial)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers())
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// needMore signals the caller to re-run when a peer phase hit its per-pass cap.
needMore := false
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return false, err
}
} else {
removeMore, err := e.removePeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
modifyMore, err := e.modifyPeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
addMore, err := e.addNewPeers(remotePeers, maxBatch)
if err != nil {
return false, err
}
needMore = removeMore || modifyMore || addMore
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
}
// Set the exclude list only once peers have fully converged (this pass added
// the last batch). It needs all target peers present in the store, and
// ExcludePeer has replace-semantics — a partial set mid-convergence would be wrong.
if !needMore {
excludedLazyPeers := e.toExcludedLazyPeers(e.forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
}
e.networkSerial = serial
return needMore, nil
}
// applyWholesale applies the cheap, local, idempotent map sections — lazy feature
// flag, firewall/legacy management, DNS, routes, ACL filtering, DNS forwarder and
// ingress forward rules — that must be in place before peers come up. It runs once
// per target (first pass only); the resulting forward rules are stashed in
// e.forwardingRules for the lazy-exclude finalize on the peer-converged pass.
func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64) {
if err := e.connMgr.UpdatedRemoteFeatureFlag(e.ctx, networkMap.GetPeerConfig().GetLazyConnectionEnabled()); err != nil {
log.Errorf("failed to update lazy connection feature flag: %v", err)
}
@@ -1495,13 +1389,16 @@ func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64)
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
done := e.phase("dns_server")
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
done()
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled
done = e.phase("routes_classify")
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1510,25 +1407,111 @@ func (e *Engine) applyWholesale(networkMap *mgmProto.NetworkMap, serial uint64)
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
done()
done = e.phase("routes_apply")
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
done()
done = e.phase("filtering")
if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
done()
done = e.phase("dns_forwarder")
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
done()
// Ingress forward rules
done = e.phase("forward_rules")
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
e.forwardingRules = forwardingRules
done()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
done = e.phase("offline_peers")
e.updateOfflinePeers(networkMap.GetOfflinePeers())
done()
remotePeers, err := e.reconcilePeers(networkMap)
if err != nil {
return err
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
done = e.phase("lazy_exclude")
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done()
e.networkSerial = serial
return nil
}
// reconcilePeers applies the remote peer list from the network map (removing,
// modifying and adding peers, then updating SSH config) and returns the remote
// peers with our own peer filtered out, for use by later sync steps.
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
for _, p := range networkMap.GetRemotePeers() {
if p.GetWgPubKey() != localPubKey {
remotePeers = append(remotePeers, p)
}
}
// cleanup request, most likely our peer has been deleted
if networkMap.GetRemotePeersIsEmpty() {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return nil, err
}
return remotePeers, nil
}
done := e.phase("removed_peers")
err := e.removePeers(remotePeers)
done()
if err != nil {
return nil, err
}
done = e.phase("modified_peers")
err = e.modifyPeers(remotePeers)
done()
if err != nil {
return nil, err
}
done = e.phase("added_peers")
err = e.addNewPeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {
@@ -1708,23 +1691,14 @@ func addrToString(addr netip.Addr) string {
}
// addNewPeers adds peers that were not know before but arrived from the Management service with the update
// addNewPeers adds up to maxBatch not-yet-present peers per call. It returns true
// when more new peers remained than the cap, so the caller re-runs.
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig, maxBatch int) (bool, error) {
added := 0
func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
for _, p := range peersUpdate {
if _, ok := e.peerStore.PeerConn(p.GetWgPubKey()); ok {
continue // already present (cheap skip), does not count toward the cap
err := e.addNewPeer(p)
if err != nil {
return err
}
if added >= maxBatch {
return true, nil // at least one more new peer remains
}
if err := e.addNewPeer(p); err != nil {
return false, err
}
added++
}
return false, nil
return nil
}
// addNewPeer add peer if connection doesn't exist

View File

@@ -124,7 +124,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
@@ -146,7 +146,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
@@ -159,7 +159,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
@@ -174,7 +174,7 @@ func TestEngine_SSH(t *testing.T) {
RemotePeersIsEmpty: false,
}
_, err = engine.updateNetworkMap(networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)

View File

@@ -433,7 +433,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
for _, c := range []testCase{case1, case2, case3, case4, case5, case6} {
t.Run(c.name, func(t *testing.T) {
_, err = engine.updateNetworkMap(c.networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(c.networkMap)
if err != nil {
t.Fatal(err)
return
@@ -460,47 +460,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
})
}
// chunked apply: with a per-pass cap smaller than the number of peers, a
// single updateNetworkMap applies one batch and reports more==true; the
// caller re-runs until convergence. (engine currently holds 0 peers.)
t.Run("chunked add converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 6,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.True(t, more, "pass 2 should signal more")
require.Len(t, engine.peerStore.PeersPubKey(), 2)
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 3 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 3)
})
t.Run("chunked remove converges over multiple passes", func(t *testing.T) {
nm := &mgmtProto.NetworkMap{
Serial: 7,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1}, // remove peer2, peer3
}
more, err := engine.updateNetworkMap(nm, 1, true)
require.NoError(t, err)
require.True(t, more, "pass 1 should signal more (2 to remove, cap 1)")
more, err = engine.updateNetworkMap(nm, 1, false)
require.NoError(t, err)
require.False(t, more, "pass 2 should converge")
require.Len(t, engine.peerStore.PeersPubKey(), 1)
})
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
@@ -671,7 +630,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
}
}()
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.clientRoutes, testCase.expectedLen, "clientRoutes len should match")
@@ -875,7 +834,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
}()
_, err = engine.updateNetworkMap(testCase.networkMap, maxPeersPerSyncPass, true)
err = engine.updateNetworkMap(testCase.networkMap)
assert.NoError(t, err, "shouldn't return error")
assert.Equal(t, testCase.expectedSerial, input.inputSerial, "serial should match")
assert.Len(t, input.inputNSGroups, testCase.expectedZonesLen, "zones len should match")

View File

@@ -1,190 +0,0 @@
package internal
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// mapStateManager is the single read/write point between the management stream
// (writes) and the convergence loop (reads/applies).
//
// The stream calls SetTarget with the latest full SyncResponse — the complete
// desired state. A single background goroutine (run) applies it to the engine in
// bounded passes via apply() until converged, releasing syncMsgMux between passes
// so other subsystems interleave. If a newer update arrives mid-flight, the loop
// coalesces: it keeps converging toward the latest target and the intermediate one
// is SKIPPED — never applied on its own (logged, no onConverged).
//
// Convergence is a single comparison: appliedGen == targetGen. targetGen
// increments on every SetTarget (an internal generation counter, so it also covers
// config-only updates that carry no network-map serial).
//
// onConverged fires once for each — and only each — map that is actually processed
// (i.e. converged as the target). Skipped/superseded maps and dropped-on-error maps
// do NOT fire it. So "sync finished in X" / RecordSyncDuration always corresponds
// to a real, completed alignment.
type mapStateManager struct {
// apply performs one bounded apply pass and reports whether more passes are needed.
// firstPass is true on the first pass of a given target, so the caller can run
// wholesale (firewall/routes/DNS/forward-rules) once per target and skip it on the
// re-runs that only drain the bounded peer batches. The manager owns this signal
// because it owns the convergence boundary; the engine need not track serials for it.
apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error)
// onConverged is called once per processed map, with the elapsed time since that
// map was received (for the sync-duration metric / "sync finished" log).
onConverged func(time.Duration)
// persist snapshots an update to disk for restore-on-restart. Called once per
// update received from management (in SetTarget), including ones later coalesced
// or skipped from apply, so the on-disk state mirrors what management last sent.
// The impl skips config-only updates (nil NetworkMap). May be nil.
persist func(*mgmProto.SyncResponse)
mu sync.Mutex
target *mgmProto.SyncResponse
targetGen uint64
appliedGen uint64
targetSetAt time.Time
wake chan struct{}
}
func newMapStateManager(apply func(update *mgmProto.SyncResponse, firstPass bool) (bool, error), persist func(*mgmProto.SyncResponse), onConverged func(time.Duration)) *mapStateManager {
return &mapStateManager{
apply: apply,
persist: persist,
onConverged: onConverged,
wake: make(chan struct{}, 1),
}
}
// SetTarget records the latest update as the desired state and wakes the loop.
// It returns immediately; convergence happens in the background. Serial-based
// staleness of the network map is still enforced inside apply (updateNetworkMap).
func (m *mapStateManager) SetTarget(update *mgmProto.SyncResponse) error {
m.mu.Lock()
// A target that has not settled yet (targetGen > appliedGen) is being superseded
// before it converged: we coalesce to the latest map and never apply this one on
// its own. It is SKIPPED — logged here, and it will not fire onConverged.
if m.target != nil && m.targetGen > m.appliedGen {
log.Debugf("sync map (gen %d) superseded before convergence, skipping", m.targetGen)
}
m.target = m.mergeTarget(m.target, update)
// Bump an internal generation counter, NOT the map serial: config-only updates
// (relay token rotation, STUN/TURN) arrive with NetworkMap == nil and carry no
// serial, yet must still be applied. Every SetTarget is therefore a distinct
// target regardless of payload. Map-serial staleness is enforced separately
// inside apply (updateNetworkMap).
m.targetGen++
m.targetSetAt = time.Now()
m.mu.Unlock()
select {
case m.wake <- struct{}{}:
default:
}
// Persist every update received from management — once per update (not per apply
// pass), and including ones that get coalesced/skipped from apply, so the on-disk
// state always reflects the latest map management sent. Done after waking the loop
// so convergence can start in parallel with the disk write. The persist impl skips
// config-only updates (nil NetworkMap).
if m.persist != nil {
m.persist(update)
}
return nil
}
// mergeTarget combines the currently pending target with a freshly received update
// and returns the new desired state. It is called under m.mu from SetTarget and is
// the single seam where the replace-vs-squash decision lives.
//
// Today management always sends a FULL map (the complete desired state), so the
// update simply replaces whatever was pending — prev is ignored. When management
// starts sending incremental/delta updates, squash `update` onto `prev` here; the
// rest of the manager (generation tracking, convergence, signaling) is unaffected
// because it already treats target as "the complete desired state, whatever it is".
func (m *mapStateManager) mergeTarget(prev, update *mgmProto.SyncResponse) *mgmProto.SyncResponse {
return update
}
// run drives convergence until ctx is done. It is meant to run in its own goroutine.
func (m *mapStateManager) run(ctx context.Context) {
// passGen is the generation of the most recent apply() call (0 = none). A pass is
// the first for its target when its generation differs from the previous one —
// true on a fresh target and on a coalesced switch to a newer target mid-flight.
var passGen uint64
for {
m.mu.Lock()
target, tg, ag := m.target, m.targetGen, m.appliedGen
m.mu.Unlock()
// Fully converged (or nothing yet): block until a new target arrives.
if target == nil || ag == tg {
select {
case <-ctx.Done():
return
case <-m.wake:
continue
}
}
firstPass := tg != passGen
passGen = tg
more, err := m.apply(target, firstPass)
if err != nil {
if ctx.Err() != nil {
return
}
// Log and DROP this target — do not retry it. A deterministic failure
// (e.g. a malformed peer in the map) would otherwise spin every pass
// making no progress. Management is the source of truth and re-delivers
// the full map on the next sync, so dropping is safe; peers already
// applied this convergence stay (idempotent diffs) and the remainder is
// reconciled by the next target. Mirrors the legacy handleSync path,
// where the apply error was logged by the gRPC client and the update
// dropped. No onConverged: this target did not converge.
log.Errorf("apply sync pass, dropping update: %v", err)
m.settle(tg, false)
continue
}
if more {
// keep converging the current target; syncMsgMux was released by apply
// between passes so other subsystems interleave.
continue
}
// This pass converged. Mark applied and signal this one map.
m.settle(tg, true)
// if a newer target arrived mid-pass, settle is a no-op (targetGen != tg) and
// ag<tg next iteration -> apply it; this generation was skipped (logged in
// SetTarget) and is not signaled.
}
}
// settle marks generation tg as processed so the loop goes idle instead of
// re-applying the same target. It is a no-op when a newer target arrived during the
// pass (targetGen != tg), leaving appliedGen behind so that target re-applies — the
// just-finished generation was already counted as skipped.
//
// When signal is true (the pass converged) it fires onConverged once for this map;
// when false (the target was dropped on error) it does not — the map did not converge.
func (m *mapStateManager) settle(tg uint64, signal bool) {
m.mu.Lock()
if m.targetGen != tg {
m.mu.Unlock()
return
}
m.appliedGen = tg
setAt := m.targetSetAt
m.mu.Unlock()
if signal && m.onConverged != nil {
m.onConverged(time.Since(setAt))
}
}

View File

@@ -1,242 +0,0 @@
package internal
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// converges over the bounded passes (apply returns more until the 3rd pass),
// fires onConverged exactly once, then blocks (no further apply) until a new target.
func TestMapStateManager_ConvergesThenStops(t *testing.T) {
var passes int32
var firstPasses int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, firstPass bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
if firstPass {
atomic.AddInt32(&firstPasses, 1)
}
return n < 3, nil // more on pass 1 and 2, converge on pass 3
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("manager did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&firstPasses), "firstPass true only on pass 1, false on re-runs of the same target")
// once converged the loop blocks: no further apply calls
time.Sleep(100 * time.Millisecond)
require.EqualValues(t, 3, atomic.LoadInt32(&passes), "apply must not run after convergence")
}
// persist runs once per received update (not per apply pass), regardless of how many
// bounded passes that target takes to converge.
func TestMapStateManager_PersistsOncePerUpdate(t *testing.T) {
var passes, persists int32
converged := make(chan struct{}, 1)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
n := atomic.AddInt32(&passes, 1)
return n < 3, nil // 3 passes for one target
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-converged:
case <-time.After(2 * time.Second):
t.Fatal("did not converge")
}
require.EqualValues(t, 3, atomic.LoadInt32(&passes))
require.EqualValues(t, 1, atomic.LoadInt32(&persists), "persist once per update, not per pass")
}
// every update received from management is persisted — even one that is coalesced /
// skipped from apply before it ever converges.
func TestMapStateManager_PersistsEveryUpdateIncludingSkipped(t *testing.T) {
release := make(chan struct{})
var persists int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
<-release // hold the first apply so the second update coalesces/skips
return false, nil
}
persist := func(*mgmProto.SyncResponse) { atomic.AddInt32(&persists, 1) }
m := newMapStateManager(apply, persist, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map1 -> apply blocks
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{})) // map2 supersedes map1 (skipped from apply)
close(release)
// both updates persisted even though map1 is skipped from apply
require.Eventually(t, func() bool { return atomic.LoadInt32(&persists) == 2 }, 2*time.Second, 10*time.Millisecond)
}
// each map that is actually processed (converged before the next arrives) fires
// onConverged exactly once — mirroring the legacy per-message handleSync timing.
func TestMapStateManager_SignalsEachProcessedMap(t *testing.T) {
converged := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged <- struct{}{} })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
const maps = 3
for i := 0; i < maps; i++ {
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select { // wait for this map to converge before sending the next (no coalescing)
case <-converged:
case <-time.After(2 * time.Second):
t.Fatalf("map %d not signaled", i)
}
}
// no extra signals once the stream goes quiet
select {
case <-converged:
t.Fatal("unexpected extra onConverged")
case <-time.After(100 * time.Millisecond):
}
}
// a map superseded before it converges is skipped: only the latest (processed) map
// fires onConverged, not the skipped one.
func TestMapStateManager_SkippedMapNotSignaled(t *testing.T) {
release := make(chan struct{})
var applies, converged atomic.Int32
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applies.Add(1)
<-release // hold the first apply in-flight so we can queue a newer target
return false, nil
}
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// map1 is picked up; its apply blocks on release
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
require.Eventually(t, func() bool { return applies.Load() >= 1 }, 2*time.Second, 5*time.Millisecond)
// map2 supersedes map1 before it settled -> map1 is skipped
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
close(release) // let both applies proceed
// only the processed (latest) map signals; the skipped one does not
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
time.Sleep(150 * time.Millisecond)
require.EqualValues(t, 1, converged.Load(), "skipped map must not fire onConverged")
require.EqualValues(t, 2, applies.Load(), "both targets entered apply (map1 once, map2 once)")
}
// an apply error drops the target: no retry of the same target, no onConverged,
// the loop goes idle — and a fresh target is still applied afterwards.
func TestMapStateManager_DropsTargetOnError(t *testing.T) {
applied := make(chan struct{}, 8)
var failNext atomic.Bool
failNext.Store(true)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
if failNext.Load() {
return false, errors.New("boom")
}
return false, nil // converge in one pass
}
var converged atomic.Int32
m := newMapStateManager(apply, nil, func(time.Duration) { converged.Add(1) })
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
// first target errors -> applied once, then dropped (no retry, no onConverged)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("errored target not applied")
}
select {
case <-applied:
t.Fatal("errored target must not be retried")
case <-time.After(150 * time.Millisecond):
}
require.EqualValues(t, 0, converged.Load(), "onConverged must not fire on error")
// a new target is still processed normally and converges
failNext.Store(false)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target after error not applied")
}
require.Eventually(t, func() bool { return converged.Load() == 1 }, 2*time.Second, 10*time.Millisecond)
}
// a new target after convergence triggers a fresh apply; an idle (converged)
// manager does not apply on its own.
func TestMapStateManager_ReappliesOnNewTarget(t *testing.T) {
applied := make(chan struct{}, 8)
apply := func(_ *mgmProto.SyncResponse, _ bool) (bool, error) {
applied <- struct{}{}
return false, nil // converge in one pass
}
m := newMapStateManager(apply, nil, nil)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go m.run(ctx)
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("first target not applied")
}
// converged → must stay idle (no spurious apply)
select {
case <-applied:
t.Fatal("unexpected apply while idle/converged")
case <-time.After(150 * time.Millisecond):
}
require.NoError(t, m.SetTarget(&mgmProto.SyncResponse{}))
select {
case <-applied:
case <-time.After(2 * time.Second):
t.Fatal("new target not applied")
}
}

View File

@@ -120,6 +120,30 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
phase,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync_phase",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {

View File

@@ -78,6 +78,25 @@ Tags:
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Sync Phase Timing
Measurement: `netbird_sync_phase`
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
Tags:
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
### Login Duration
Measurement: `netbird_login`
@@ -191,4 +210,52 @@ docker compose exec influxdb influx query \
# Check ingest server health
curl http://localhost:8087/health
```
```
## Analyzing a Debug Bundle
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
### 1. Start the stack
```bash
# From this directory (client/internal/metrics/infra)
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
docker compose up -d
```
(`admin123` are throwaway local credentials — fine for offline analysis.)
### 2. Clear any previous data
So you only see this bundle:
```bash
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
```
### 3. Import the bundle's metrics.txt
InfluxDB is not exposed on the host, so import inside the container:
```bash
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
--token admin123 --file /tmp/m.txt
```
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
### 4. View the dashboards
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.

View File

@@ -0,0 +1,259 @@
{
"annotations": {
"list": []
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"refresh": "",
"schemaVersion": 39,
"tags": [
"netbird",
"sync"
],
"templating": {
"list": [
{
"current": {
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"includeAll": true,
"label": "version",
"multi": true,
"name": "version",
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"refresh": 2,
"type": "query",
"allValue": ".*"
}
]
},
"time": {
"from": "now-2d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "NetBird Sync Phases (where time goes)",
"uid": "netbird-sync-phases",
"version": 1,
"panels": [
{
"id": 1,
"title": "Time per phase over time (stacked, ms)",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "bars",
"stacking": {
"mode": "normal",
"group": "A"
},
"fillOpacity": 80,
"lineWidth": 0
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
}
]
},
{
"id": 2,
"title": "p95 per phase (ms)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"color": {
"mode": "continuous-GrYlRd"
}
},
"overrides": []
},
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
}
]
},
{
"id": 3,
"title": "Per-phase stats (ms): mean / p95 / max",
"type": "table",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms"
},
"overrides": []
},
"options": {
"showHeader": true,
"sortBy": [
{
"displayName": "max",
"desc": true
}
]
},
"transformations": [
{
"id": "merge",
"options": {}
}
],
"targets": [
{
"refId": "mean",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
},
{
"refId": "p95",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
},
{
"refId": "max",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
}
]
},
{
"id": 4,
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 24,
"x": 0,
"y": 21
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "points",
"pointSize": 5
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "single"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
}
]
}
]
}

View File

@@ -19,7 +19,7 @@ const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
@@ -59,6 +59,19 @@ var allowedMeasurements = map[string]measurementSpec{
"peer_id": true,
},
},
"netbird_sync_phase": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"phase": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,

View File

@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")

View File

@@ -56,6 +56,9 @@ type metricsImplementation interface {
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordSyncPhase records how long a single sub-phase of sync processing took
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
@@ -127,6 +130,18 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordSyncPhase records the duration of a single sub-phase of sync processing
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {

View File

@@ -70,6 +70,9 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}

View File

@@ -226,12 +226,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
// All query types for an intercepted domain are forwarded to the peer's
// DNS forwarder, which owns the name. Falling through to the system
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
// for, poisoning the whole name (including the A/AAAA records the route
// does serve). The forwarder answers NODATA for types it cannot resolve.
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
@@ -293,19 +292,6 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {

View File

@@ -55,6 +55,14 @@ type GrpcClient struct {
connStateCallback ConnStateNotifier
connStateCallbackLock sync.RWMutex
serverURL string
// syncStreamErr holds the last Sync stream error, or nil while the stream
// is established and healthy. GetServerKey succeeds even when the peer
// cannot sync (e.g. the server returns "settings not found"), so the
// health probe must consult this to avoid reporting a healthy management
// connection while the Sync stream keeps failing.
syncStreamMu sync.RWMutex
syncStreamErr error
}
type ExposeRequest struct {
@@ -364,6 +372,8 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
stream, err := c.connectToSyncStream(ctx, serverPubKey, sysInfo)
if err != nil {
log.Debugf("failed to open Management Service stream: %s", err)
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.PermissionDenied {
return backoff.Permanent(err) // unrecoverable error, propagate to the upper layer
}
@@ -372,11 +382,13 @@ func (c *GrpcClient) handleSyncStream(ctx context.Context, serverPubKey wgtypes.
log.Infof("connected to the Management Service stream")
c.notifyConnected()
c.setSyncStreamConnected()
// blocking until error
err = c.receiveUpdatesEvents(stream, serverPubKey, msgHandler)
if err != nil {
c.notifyDisconnected(err)
c.setSyncStreamDisconnected(err)
if ctx.Err() != nil {
log.Debugf("management connection context has been canceled, this usually indicates shutdown")
return nil
@@ -530,6 +542,13 @@ func (c *GrpcClient) IsHealthy() bool {
log.Warnf("health check returned: %s", err)
return false
}
if syncErr := c.syncStreamError(); syncErr != nil {
c.notifyDisconnected(syncErr)
log.Warnf("management transport is up but the Sync stream is unhealthy: %s", syncErr)
return false
}
c.notifyConnected()
return true
}
@@ -771,6 +790,24 @@ func (c *GrpcClient) SyncMeta(sysInfo *system.Info) error {
return err
}
func (c *GrpcClient) setSyncStreamConnected() {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = nil
}
func (c *GrpcClient) setSyncStreamDisconnected(err error) {
c.syncStreamMu.Lock()
defer c.syncStreamMu.Unlock()
c.syncStreamErr = err
}
func (c *GrpcClient) syncStreamError() error {
c.syncStreamMu.RLock()
defer c.syncStreamMu.RUnlock()
return c.syncStreamErr
}
func (c *GrpcClient) notifyDisconnected(err error) {
c.connStateCallbackLock.RLock()
defer c.connStateCallbackLock.RUnlock()

View File

@@ -85,6 +85,7 @@ type GrpcClient struct {
// receive backpressure as a dead stream: reconnecting cannot help, since the
// new stream feeds the same worker, and only triggers a reconnect storm.
receiveHandoffBlocked atomic.Bool
watchdogWg sync.WaitGroup
}
// NewClient creates a new Signal client
@@ -200,10 +201,18 @@ func (c *GrpcClient) Receive(ctx context.Context, msgHandler func(msg *proto.Mes
// Guard the receive direction: the transport can stay healthy while the
// server stops delivering messages. The watchdog reconnects via cancelStream.
c.markReceived()
go c.watchReceiveStream(streamCtx, cancelStream)
c.watchdogWg.Add(1)
go func() {
defer c.watchdogWg.Done()
c.watchReceiveStream(streamCtx, cancelStream)
}()
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream)
cancelStream()
c.watchdogWg.Wait()
if err != nil {
// Check the parent context, not streamCtx: a watchdog-triggered
// cancelStream must reconnect, only a parent cancel is shutdown.
@@ -400,7 +409,12 @@ func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage
// Send sends a message to the remote Peer through the Signal Exchange.
func (c *GrpcClient) Send(msg *proto.Message) error {
return c.send(c.ctx, msg)
}
// send delivers a message deriving per-attempt timeouts from parentCtx, so a
// caller can abort an in-flight send by cancelling that context.
func (c *GrpcClient) send(parentCtx context.Context, msg *proto.Message) error {
if !c.Ready() {
return fmt.Errorf("no connection to signal")
}
@@ -416,7 +430,7 @@ func (c *GrpcClient) Send(msg *proto.Message) error {
if attempt > 1 {
attemptTimeout = time.Duration(attempt) * 5 * time.Second
}
ctx, cancel := context.WithTimeout(c.ctx, attemptTimeout)
ctx, cancel := context.WithTimeout(parentCtx, attemptTimeout)
_, err = c.realClient.Send(ctx, encryptedMessage)
@@ -486,7 +500,7 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
}
if probeSentAt.IsZero() {
if err := c.sendReceiveProbe(); err != nil {
if err := c.sendReceiveProbe(ctx); err != nil {
log.Debugf("failed to send signal receive probe: %v", err)
}
probeSentAt = time.Now()
@@ -495,11 +509,13 @@ func (c *GrpcClient) watchReceiveStream(ctx context.Context, cancelStream contex
}
}
// sendReceiveProbe sends a self-addressed heartbeat. The Signal server routes it
// back to this client, exercising the exact receive path the watchdog guards.
func (c *GrpcClient) sendReceiveProbe() error {
// sendReceiveProbe sends a self-addressed heartbeat bound to ctx, so cancelStream
// aborts an in-flight probe instead of leaving the watchdog blocked on send timeouts.
// The Signal server routes it back to this client, exercising the exact receive
// path the watchdog guards.
func (c *GrpcClient) sendReceiveProbe(ctx context.Context) error {
self := c.key.PublicKey().String()
return c.Send(&proto.Message{
return c.send(ctx, &proto.Message{
Key: self,
RemoteKey: self,
Body: &proto.Body{Type: proto.Body_HEARTBEAT},
@@ -541,6 +557,9 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient) er
if err := c.decryptionWorker.AddMsg(c.ctx, msg); err != nil {
log.Errorf("failed to add message to decryption worker: %v", err)
}
// Refresh liveness before clearing the flag so the window between here and
// the next Recv does not read a stale timestamp as a dead stream.
c.markReceived()
c.receiveHandoffBlocked.Store(false)
}
}

View File

@@ -2,6 +2,7 @@ package client
import (
"context"
"io"
"net"
"testing"
"time"
@@ -74,7 +75,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
t.Fatal("signal stream did not connect within timeout")
}
require.NoError(t, client.sendReceiveProbe())
require.NoError(t, client.sendReceiveProbe(ctx))
select {
case <-received:
@@ -106,3 +107,72 @@ func TestReceiveAliveTreatsHandoffBlockAsLiveness(t *testing.T) {
c.markReceived()
require.True(t, c.receiveAlive(), "a freshly received frame must keep the stream alive")
}
// fakeRecvStream feeds the receive loop frames from a channel and reports EOF
// once the channel is closed. Only Recv is exercised by the loop.
type fakeRecvStream struct {
sigProto.SignalExchange_ConnectStreamClient
frames chan *sigProto.EncryptedMessage
}
func (s *fakeRecvStream) Recv() (*sigProto.EncryptedMessage, error) {
msg, ok := <-s.frames
if !ok {
return nil, io.EOF
}
return msg, nil
}
// TestReceiveLoopRefreshesLivenessAfterBlockedHandoff drives the real receive
// loop into a handoff that blocks past the inactivity threshold, then checks the
// window after the handoff drains but before the next Recv. The loop must have
// refreshed the timestamp on unblocking, otherwise that window reads the stale
// pre-handoff timestamp as a dead stream and the watchdog tears down a healthy
// connection.
func TestReceiveLoopRefreshesLivenessAfterBlockedHandoff(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
c := &GrpcClient{ctx: ctx}
handling := make(chan struct{}, 8)
gate := make(chan struct{})
decrypt := func(*sigProto.EncryptedMessage) (*sigProto.Message, error) { return &sigProto.Message{}, nil }
handler := func(*sigProto.Message) error {
handling <- struct{}{}
<-gate
return nil
}
c.decryptionWorker = NewWorker(decrypt, handler)
workerCtx, workerCancel := context.WithCancel(context.Background())
go c.decryptionWorker.Work(workerCtx)
t.Cleanup(workerCancel)
frames := make(chan *sigProto.EncryptedMessage)
t.Cleanup(func() { close(frames) })
go func() { _ = c.receive(&fakeRecvStream{frames: frames}) }()
// First frame: the worker drains it and parks in the blocking handler.
frames <- &sigProto.EncryptedMessage{}
<-handling
// Second frame fills the worker's single-slot pool.
frames <- &sigProto.EncryptedMessage{}
// Third frame: the pool is full, so the loop parks on the handoff.
frames <- &sigProto.EncryptedMessage{}
require.Eventually(t, c.receiveHandoffBlocked.Load, time.Second, time.Millisecond,
"receive loop should park on the worker handoff")
// Simulate the handoff having blocked past the inactivity threshold.
c.lastReceived.Store(time.Now().Add(-2 * receiveInactivityThreshold).UnixNano())
require.True(t, c.receiveAlive(), "a loop parked on the handoff must stay alive")
// Drain the worker so the handoff returns and the loop resumes reading.
close(gate)
// Once the handoff clears, the loop is parked on the next Recv with no frame
// pending. The stream must still read as alive in that window.
require.Eventually(t, func() bool { return !c.receiveHandoffBlocked.Load() }, time.Second, time.Millisecond,
"handoff should drain once the worker is released")
require.True(t, c.receiveAlive(),
"the loop must refresh liveness when the handoff drains, before the next Recv")
}