Compare commits

...

55 Commits

Author SHA1 Message Date
Viktor Liu
6cb25de9ea Include MTU and SSH auth/JWT cache config in debug bundle 2026-05-05 12:18:56 +02:00
Pascal Fischer
97db824929 [management] fix proxy reconnect (#6063) 2026-05-04 20:43:25 +02:00
Viktor Liu
77a0992dc2 [misc] Disable govet inline analyzer and tidy go.mod (#6066) 2026-05-05 02:59:41 +09:00
JungwooShin
104990dfdd [client] Display QR code for device auth login URL (#5415) 2026-05-04 18:59:29 +02:00
alexsavio
bde632c3b2 [client] Replace WG interface monitor polling with netlink subscription on Linux (#5857) 2026-05-04 18:49:39 +02:00
Lauri Tirkkonen
4268a5cfb7 [client] Use atomic write/rename pattern for ssh config 2026-05-04 18:24:52 +02:00
Zoltan Papp
a547fc74ed [client] Use ctx.Err() instead of gRPC codes.Canceled to detect shutdown (#6019)
Detecting shutdown by inspecting the gRPC status code conflates a local
context cancellation with a server- or proxy-sent codes.Canceled. When
the latter occurs (e.g. an intermediary proxy resets the stream), the
retry loop silently terminates and the client never reconnects.

Switch to ctx.Err() in the signal Receive loop and management Sync/Job
handlers, and stop matching gRPC Canceled/DeadlineExceeded in the flow
client's isContextDone helper. With this change, a server-sent Canceled
is treated as a transient error and the backoff retry loop continues.
2026-05-04 11:59:25 +02:00
Zoltan Papp
a21f6ecb0a [client] release Status.mux before invoking notifier callbacks (#6039)
The Status recorder used to fire notifier callbacks while holding d.mux:
- notifyPeerListChanged / notifyPeerStateChangeListeners ran from inside
  the locked section of every Update*/AddPeerStateRoute/etc.
- notifyAddressChanged ran from UpdateLocalPeerState and CleanLocalPeerState
  while d.mux was held.
- onConnectionChanged was registered with a defer above defer d.mux.Unlock,
  so it executed before the mutex was released in the Mark*Connected/
  Disconnected helpers.
- notifyPeerStateChangeListeners did a blocking channel send under d.mux,
  so a slow subscriber stalled every other d.mux holder.

A listener that re-enters the recorder (e.g. calls GetFullStatus from
within a callback) deadlocks against d.mux, and any callback that takes
longer than expected stalls every other state query for its duration.

Capture the values needed for notification under the lock, release d.mux,
then call the notifier. Build per-peer router-state snapshots inside the
lock and dispatch them via dispatchRouterPeers afterwards. The router-peer
channel send stays blocking, but now happens outside d.mux so a slow
consumer cannot stall any other d.mux holder, and no peer state
transitions are silently dropped.

The notifier itself is unchanged: its internal state was already protected
by its own locks, and the field d.notifier is set once in NewRecorder and
never reassigned, so reading it without d.mux is safe.

Also fix a pre-existing race in Test_notifier_RemoveListener /
Test_notifier_SetListener: setListener spawns a goroutine that writes
listener.peers, but the tests read listener.peers without waiting for it.
2026-05-04 11:59:01 +02:00
Bethuel Mmbaga
6262b0d841 [management] Track pending approval in peer event metadata (#6040) 2026-05-04 12:47:13 +03:00
Viktor Liu
50b58a6828 [client, relay] Advertise relay server IP via signal for foreign-relay fallback dial (#6004) 2026-05-04 11:40:25 +02:00
Viktor Liu
057d651d2e [client, proxy] Add packet capture to debug bundle and CLI (#5891) 2026-05-04 11:28:56 +02:00
Misha Bragin
c4b2da4c92 [management] Add public connection ipv4 and ipv6 posture check (#6038)
This change enables admins to configure posture checks for connecting public IPs of their peers.

It changes the behavior of the check as well and now the evaluation is if the received network is part of the configured network.
2026-04-30 18:36:50 +02:00
Nicolas Frati
dcd1db42ef [management] Enable PAT creation during setup (#6003)
* enable pat creation on setup

* remove logic from handler towards setup service

* fix lint issue

* fix rollback on account id returning empty

* fix coderabbit comments

* fix setup PAT rollback behavior
2026-04-30 17:21:35 +02:00
Pascal Fischer
f29f5a0978 [management] add monitoring for nmap update source (#6036) 2026-04-30 14:52:54 +02:00
Maycon Santos
3fc5a8d4a1 [misc] fix MSI generation add installer tests (#6031)
Add Windows installer build test workflow
2026-04-29 23:44:38 +02:00
Zoltan Papp
57945fc328 [client] Trigger mobile submodule bump PRs on release tags (#6029)
Trigger mobile submodule bump PRs on release tags
2026-04-29 17:19:22 +02:00
Viktor Liu
ed828b7af4 Tolerate EEXIST when adding macOS scoped default routes (#6027) 2026-04-29 16:08:47 +02:00
Viktor Liu
11ac2af2f5 Use BindListener for all userspace bind in lazyconn activity (#6028) 2026-04-29 16:07:33 +02:00
Bethuel Mmbaga
df197d5001 [management] Prevent JWT reuse during peer login (#6002) 2026-04-29 15:04:27 +03:00
shuuri-labs
ad93dcf980 [client] Enable UI autostart for silent and MSI installs (#6026)
* fix(client): enable UI autostart for silent and MSI installs

The MSI installer had no autostart logic and the EXE silent installer
skipped the autostart page, leaving the registry entry unwritten. This
caused the NetBird UI tray to not start at login after RMM deployments.

Add an AUTOSTART property (default: 1) to the MSI that writes the
HKLM Run key, and initialize AutostartEnabled in the NSIS .onInit so
silent installs match the interactive default.

* add real guid for NetBirdAutoStart component
2026-04-29 13:14:46 +02:00
Nicolas Frati
7eba5dafd8 [misc] Add comment automation on release workflow for PRs (#6016)
* feat: add comment automation on release workflow for PRs

* update action permissions
2026-04-29 11:28:55 +02:00
Viktor Liu
28fe26637b [client] Fix Windows installer upgrade detection for pre-0.70.1 installs (#6025) 2026-04-29 11:01:07 +02:00
Viktor Liu
407e9d304b [client] Move macOS sleep detection into the daemon (purego) (#5926) 2026-04-29 08:09:55 +02:00
Viktor Liu
e5474e199f [client] Use WinRT COM for Windows toasts (#6013)
* Use WinRT COM for Windows toasts instead of fyne's PowerShell path

* Quote autostart path and split HKCU registry into per-user component
2026-04-28 20:54:06 +02:00
Bethuel Mmbaga
db44848e2d [management] Drop netmap calculation on peer read (#6006) 2026-04-28 18:25:56 +03:00
EL OUAZIZI Walid
9417ce3b3a fix(getting-started): Infinite healthcheck loop with existing traefik (#5871) 2026-04-28 17:22:51 +02:00
Zoltan Papp
8fc4265995 [relay] evict foreign client cache on disconnect (#6015)
* [relay] evict foreign client cache on disconnect

When a foreign relay's TCP connection drops, the manager's
onServerDisconnected handler only triggered reconnect logic for the
home server; the disconnected foreign entry stayed in the relayClients
cache. Subsequent OpenConn calls reused the closed client until the
60-second cleanup tick evicted it, breaking peer connectivity through
that relay for up to a minute.

Evict the foreign entry from the cache on disconnect so the next
OpenConn dials a fresh client.

Also:
- Make the reconnect backoff cap configurable via WithMaxBackoffInterval
  ManagerOption; the previous hard-coded 60s constant forced
  TestAutoReconnect to sleep ~61s. Test now polls Ready() and finishes
  in ~2s.
- Add NB_HOME_RELAY_SERVERS env var that overrides the relay URL list
  received from management, so a peer can be pinned to a specific home
  relay (used by the netbird-conn-lab Edge 4 reproducer).

* [client] treat empty NB_HOME_RELAY_SERVERS as unset

Returning (urls=[], ok=true) when the env var contained only separators or
whitespace caused callers to wipe the mgmt-provided relay list, leaving the
peer with no relays. Treat a parsed-empty result the same as an unset env.
2026-04-28 15:04:48 +02:00
Zoltan Papp
9c50819f20 Don't mark management disconnected on transient job stream errors (#6005)
The JOB stream is a separate channel from the SYNC stream. Server-side
EOF or other transient errors on the JOB stream do not indicate that
the management connection is unhealthy — the SYNC stream remains the
authoritative state signal.

Previously, a JOB stream EOF would call notifyDisconnected and the
client would emit OnConnecting to the UI. The backoff retry would
reconnect the JOB stream, but handleJobStream never calls notifyConnected
on success, so the UI was stuck on "Connecting" until the next SYNC
event or health check.

Keep notifyDisconnected for codes.PermissionDenied since IsLoginRequired
relies on managementError to detect expired auth.
2026-04-28 15:04:41 +02:00
Bethuel Mmbaga
6f0eff3ba0 [management] Handle single-string JWT group claim from IdPs (#6014) 2026-04-28 14:48:28 +03:00
Bethuel Mmbaga
f8745723fc [management] Add Microsoft AD FS support for embedded Dex identity providers (#6008) 2026-04-28 12:42:19 +03:00
Vlad
154b81645a [management] removed legacy network map code (#5565) 2026-04-27 16:02:54 +02:00
Maycon Santos
34167c8a16 [misc] Update release pipeline version (#5995) 2026-04-27 10:55:38 +02:00
Maycon Santos
d6f08e4840 [misc] Update sign pipeline version (#5981) 2026-04-24 13:13:27 +02:00
Zoltan Papp
f732b01a05 [management] unify peer-update test timeout via constant (#5952)
peerShouldReceiveUpdate waited 500ms for the expected update message,
and every outer wrapper across the management/server test suite paired
it with a 1s goroutine-drain timeout. Both were too tight for slower
CI runners (MySQL, FreeBSD, loaded sqlite), producing intermittent
"Timed out waiting for update message" failures in tests like
TestDNSAccountPeersUpdate, TestPeerAccountPeersUpdate, and
TestNameServerAccountPeersUpdate.

Introduce peerUpdateTimeout (5s) next to the helper and use it both in
the helper and in every outer wrapper so the two timeouts stay in sync.
Only runs down on failure; passing tests return as soon as the channel
delivers, so there is no slowdown on green runs.
2026-04-23 21:19:21 +02:00
alsruf36
c07c726ea7 [proxy] Set session cookie path to root (#5915) 2026-04-23 18:20:54 +02:00
Pascal Fischer
fa0d58d093 [management] exclude peers for expiration job that have already been marked expired (#5970) 2026-04-23 16:01:54 +02:00
Vlad
b6038e8acd [management] refactor: changeable pat rate limiting (#5946) 2026-04-23 15:13:22 +02:00
Zoltan Papp
5da05ecca6 [client] increase gRPC health check timeout to 5s (#5961)
Bump the IsHealthy() context timeout from 1s to 5s for both the
management and signal gRPC clients to reduce false negatives on
slower or congested connections.
2026-04-22 20:54:18 +02:00
Viktor Liu
801de8c68d [client] Add TTL-based refresh to mgmt DNS cache via handler chain (#5945) 2026-04-22 15:10:14 +02:00
Viktor Liu
a822a33240 [self-hosted] Use cscli lapi status for CrowdSec readiness in installer (#5949) 2026-04-22 10:35:22 +02:00
Bethuel Mmbaga
57b23c5b25 [management] Propagate context changes to upstream middleware (#5956) 2026-04-21 23:06:52 +03:00
Zoltan Papp
1165058fad [client] fix port collision in TestUpload (#5950)
* [debug] fix port collision in TestUpload

TestUpload hardcoded :8080, so it failed deterministically when anything
was already on that port and collided across concurrent test runs.
Bind a :0 listener in the test to get a kernel-assigned free port, and
add Server.Serve so tests can hand the listener in without reaching
into unexported state.

* [debug] drop test-only Server.Serve, use SERVER_ADDRESS env

The previous commit added a Server.Serve method on the upload-server,
used only by TestUpload. That left production with an unused function.
Reserve an ephemeral loopback port in the test, release it, and pass
the address through SERVER_ADDRESS (which the server already reads).
A small wait helper ensures the server is accepting connections before
the upload runs, so the close/rebind gap does not cause a false failure.
2026-04-21 19:07:20 +02:00
Zoltan Papp
703353d354 [flow] fix goroutine leak in TestReceive_ProtocolErrorStreamReconnect (#5951)
The Receive goroutine could outlive the test and call t.Logf after
teardown, panicking with "Log in goroutine after ... has completed".
Register a cleanup that waits for the goroutine to exit; ordering is
LIFO so it runs after client.Close, which is what unblocks Receive.
2026-04-21 19:06:47 +02:00
Zoltan Papp
2fb50aef6b [client] allow UDP packet loss in TestICEBind_HandlesConcurrentMixedTraffic (#5953)
The test writes 500 packets per family and asserted exact-count
delivery within a 5s window, even though its own comment says "Some
packet loss is acceptable for UDP". On FreeBSD/QEMU runners the writer
loops cannot always finish all 500 before the 5s deadline closes the
readers (we have seen 411/500 in CI).

The real assertion of this test is the routing check — IPv4 peer only
gets v4- packets, IPv6 peer only gets v6- packets — which remains
strict. Replace the exact-count assertions with a >=80% delivery
threshold so runner speed variance no longer causes false failures.
2026-04-21 19:05:58 +02:00
Vlad
eb3aa96257 [management] check policy for changes before actual db update (#5405) 2026-04-21 18:37:04 +02:00
Viktor Liu
064ec1c832 [client] Trust wg interface in firewalld to bypass owner-flagged chains (#5928) 2026-04-21 17:57:16 +02:00
Viktor Liu
75e408f51c [client] Prefer systemd-resolved stub over file mode regardless of resolv.conf header (#5935) 2026-04-21 17:56:56 +02:00
Zoltan Papp
5a89e6621b [client] Supress ICE signaling (#5820)
* [client] Suppress ICE signaling and periodic offers in force-relay mode

When NB_FORCE_RELAY is enabled, skip WorkerICE creation entirely,
suppress ICE credentials in offer/answer messages, disable the
periodic ICE candidate monitor, and fix isConnectedOnAllWay to
only check relay status so the guard stops sending unnecessary offers.

* [client] Dynamically suppress ICE based on remote peer's offer credentials

Track whether the remote peer includes ICE credentials in its
offers/answers. When remote stops sending ICE credentials, skip
ICE listener dispatch, suppress ICE credentials in responses, and
exclude ICE from the guard connectivity check. When remote resumes
sending ICE credentials, re-enable all ICE behavior.

* [client] Fix nil SessionID panic and force ICE teardown on relay-only transition

Fix nil pointer dereference in signalOfferAnswer when SessionID is nil
(relay-only offers). Close stale ICE agent immediately when remote peer
stops sending ICE credentials to avoid traffic black-hole during the
ICE disconnect timeout.

* [client] Add relay-only fallback check when ICE is unavailable

Ensure the relay connection is supported with the peer when ICE is disabled to prevent connectivity issues.

* [client] Add tri-state connection status to guard for smarter ICE retry (#5828)

* [client] Add tri-state connection status to guard for smarter ICE retry

Refactor isConnectedOnAllWay to return a ConnStatus enum (Connected,
Disconnected, PartiallyConnected) instead of a boolean. When relay is
up but ICE is not (PartiallyConnected), limit ICE offers to 3 retries
with exponential backoff then fall back to hourly attempts, reducing
unnecessary signaling traffic. Fully disconnected peers continue to
retry aggressively. External events (relay/ICE disconnect, signal/relay
reconnect) reset retry state to give ICE a fresh chance.

* [client] Clarify guard ICE retry state and trace log trigger

Split iceRetryState.attempt into shouldRetry (pure predicate) and
enterHourlyMode (explicit state transition) so the caller in
reconnectLoopWithRetry reads top-to-bottom. Restore the original
trace-log behavior in isConnectedOnAllWay so it only logs on full
disconnection, not on the new PartiallyConnected state.

* [client] Extract pure evalConnStatus and add unit tests

Split isConnectedOnAllWay into a thin method that snapshots state and
a pure evalConnStatus helper that takes a connStatusInputs struct, so
the tri-state decision logic can be exercised without constructing
full Worker or Handshaker objects. Add table-driven tests covering
force-relay, ICE-unavailable and fully-available code paths, plus
unit tests for iceRetryState budget/hourly transitions and reset.

* [client] Improve grammar in logs and refactor ICE credential checks
2026-04-21 15:52:08 +02:00
Misha Bragin
06dfa9d4a5 [management] replace mailru/easyjson with netbirdio/easyjson fork (#5938) 2026-04-21 13:59:35 +02:00
Misha Bragin
45d9ee52c0 [self-hosted] add reverse proxy retention fields to combined YAML (#5930) 2026-04-21 10:21:11 +02:00
Zoltan Papp
3098f48b25 [client] fix ios network addresses mac filter (#5906)
* fix(client): skip MAC address filter for network addresses on iOS

iOS does not expose hardware (MAC) addresses due to Apple's privacy
restrictions (since iOS 14), causing networkAddresses() to return an
empty list because all interfaces are filtered out by the HardwareAddr
check. Move networkAddresses() to platform-specific files so iOS can
skip this filter.
2026-04-20 11:49:38 +02:00
Zoltan Papp
7f023ce801 [client] Android debug bundle support (#5888)
Add Android debug bundle support with Troubleshoot UI
2026-04-20 11:26:30 +02:00
Michael Uray
e361126515 [client] Fix WGIface.Close deadlock when DNS filter hook re-enters GetDevice (#5916)
WGIface.Close() took w.mu and held it across w.tun.Close(). The
underlying wireguard-go device waits for its send/receive goroutines to
drain before Close() returns, and some of those goroutines re-enter
WGIface during shutdown. In particular, the userspace packet filter DNS
hook in client/internal/dns.ServiceViaMemory.filterDNSTraffic calls
s.wgInterface.GetDevice() on every packet, which also needs w.mu. With
the Close-side holding the mutex, the read goroutine blocks in
GetDevice and Close waits forever for that goroutine to exit:

  goroutine N (TestDNSPermanent_updateUpstream):
    WGIface.Close -> holds w.mu -> tun.Close -> sync.WaitGroup.Wait
  goroutine M (wireguard read routine):
    FilteredDevice.Read -> filterOutbound -> udpHooksDrop ->
    filterDNSTraffic.func1 -> WGIface.GetDevice -> sync.Mutex.Lock

This surfaces as a 5 minute test timeout on the macOS Client/Unit
CI job (panic: test timed out after 5m0s, running tests:
TestDNSPermanent_updateUpstream).

Release w.mu before calling w.tun.Close(). The other Close steps
(wgProxyFactory.Free, waitUntilRemoved, Destroy) do not mutate any
fields guarded by w.mu beyond what Free() already does, so the lock
is not needed once the tun has started shutting down. A new unit test
in iface_close_test.go uses a fake WGTunDevice to reproduce the
deadlock deterministically without requiring CAP_NET_ADMIN.
2026-04-20 10:36:19 +02:00
Viktor Liu
95213f7157 [client] Use Match host+exec instead of Host+Match in SSH client config (#5903) 2026-04-20 10:24:11 +02:00
Viktor Liu
2e0e3a3601 [client] Replace exclusion routes with scoped default + IP_BOUND_IF on macOS (#5918) 2026-04-20 10:01:01 +02:00
258 changed files with 14542 additions and 7879 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.1.2" SIGN_PIPE_VER: "v0.1.4"
GORELEASER_VER: "v2.14.3" GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"
@@ -114,7 +114,13 @@ jobs:
retention-days: 30 retention-days: 30
release: release:
runs-on: ubuntu-latest-m runs-on: ubuntu-24.04-8-core
outputs:
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env: env:
flags: "" flags: ""
steps: steps:
@@ -213,10 +219,13 @@ jobs:
if: always() if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only) - name: Tag and push images (amd64 only)
id: tag_and_push_images
if: | if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) || (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main') (github.event_name == 'push' && github.ref == 'refs/heads/main')
run: | run: |
set -euo pipefail
resolve_tags() { resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}" echo "pr-${{ github.event.pull_request.number }}"
@@ -225,6 +234,17 @@ jobs:
fi fi
} }
ghcr_package_url() {
local image="$1" package encoded_package
package="${image#ghcr.io/}"
package="${package#*/}"
package="${package%%:*}"
encoded_package="${package//\//%2F}"
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
}
image_refs=()
tag_and_push() { tag_and_push() {
local src="$1" img_name tag dst local src="$1" img_name tag dst
img_name="${src%%:*}" img_name="${src%%:*}"
@@ -233,35 +253,56 @@ jobs:
echo "Tagging ${src} -> ${dst}" echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst" docker tag "$src" "$dst"
docker push "$dst" docker push "$dst"
image_refs+=("$dst")
done done
} }
export -f tag_and_push resolve_tags cat > /tmp/goreleaser-artifacts.json <<'JSON'
${{ steps.goreleaser.outputs.artifacts }}
JSON
echo '${{ steps.goreleaser.outputs.artifacts }}' | \ mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \ jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
grep '^ghcr.io/' | while read -r SRC; do )
tag_and_push "$SRC"
done for src in "${src_images[@]}"; do
tag_and_push "$src"
done
{
echo "images_markdown<<EOF"
if [[ ${#image_refs[@]} -eq 0 ]]; then
echo "_No GHCR images were pushed._"
else
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
done
fi
echo "EOF"
} >> "$GITHUB_OUTPUT"
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release name: release
path: dist/ path: dist/
retention-days: 7 retention-days: 7
- name: upload linux packages - name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: linux-packages name: linux-packages
path: dist/netbird_linux** path: dist/netbird_linux**
retention-days: 7 retention-days: 7
- name: upload windows packages - name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: windows-packages name: windows-packages
path: dist/netbird_windows** path: dist/netbird_windows**
retention-days: 7 retention-days: 7
- name: upload macos packages - name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: macos-packages name: macos-packages
@@ -270,6 +311,8 @@ jobs:
release_ui: release_ui:
runs-on: ubuntu-latest runs-on: ubuntu-latest
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps: steps:
- name: Parse semver string - name: Parse semver string
id: semver_parser id: semver_parser
@@ -360,6 +403,7 @@ jobs:
if: always() if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui name: release-ui
@@ -368,6 +412,8 @@ jobs:
release_ui_darwin: release_ui_darwin:
runs-on: macos-latest runs-on: macos-latest
outputs:
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
steps: steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -402,15 +448,258 @@ jobs:
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes - name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
with: with:
name: release-ui-darwin name: release-ui-darwin
path: dist/ path: dist/
retention-days: 3 retention-days: 3
trigger_signer: test_windows_installer:
name: "Windows Installer / Build Test"
runs-on: windows-2022
needs: [release, release_ui]
strategy:
fail-fast: false
matrix:
include:
- arch: amd64
wintun_arch: amd64
- arch: arm64
wintun_arch: arm64
defaults:
run:
shell: powershell
env:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@v4
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@v4
with:
name: release-ui
path: release-ui
- name: Stage binaries into dist
run: |
$workdir = "dist\${{ env.PackageWorkdir }}"
New-Item -ItemType Directory -Force -Path $workdir | Out-Null
$client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
$ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 }
if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 }
Write-Host "Client: $($client.FullName)"
Write-Host "UI: $($ui.FullName)"
tar -zvxf $client.FullName -C $workdir
tar -zvxf $ui.FullName -C $workdir
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompress wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
- name: Move opengl32.dll into dist (amd64 only)
if: matrix.arch == 'amd64'
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
- name: Install WiX
run: |
dotnet tool install --global wix --version 6.0.2
wix extension add WixToolset.Util.wixext/6.0.2
- name: Build MSI installer
env:
NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}"
run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }}
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
netbird_installer_test_windows_${{ matrix.arch }}.exe
netbird_installer_test_windows_${{ matrix.arch }}.msi
retention-days: 3
comment_release_artifacts:
name: Comment release artifacts
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin] needs: [release, release_ui, release_ui_darwin]
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const marker = '<!-- netbird-release-artifacts -->';
const { owner, repo } = context.repo;
const issue_number = context.payload.pull_request.number;
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
const artifactCell = (url, result) => {
if (url) return `[Download](${url})`;
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
};
const artifacts = [
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
];
const artifactRows = artifacts
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
.join('\n');
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
const body = [
marker,
'## Release artifacts',
'',
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
'',
'| Artifact | Link |',
'| --- | --- |',
artifactRows,
'',
'### GHCR images (amd64)',
ghcrImages,
'',
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
].join('\n');
const comments = await github.paginate(github.rest.issues.listComments, {
owner,
repo,
issue_number,
per_page: 100,
});
const previous = comments.find(comment =>
comment.user?.type === 'Bot' && comment.body?.includes(marker)
);
if (previous) {
await github.rest.issues.updateComment({
owner,
repo,
comment_id: previous.id,
body,
});
core.info(`Updated release artifacts comment ${previous.id}`);
} else {
const { data } = await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
core.info(`Created release artifacts comment ${data.id}`);
}
trigger_signer:
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin, test_windows_installer]
if: startsWith(github.ref, 'refs/tags/') if: startsWith(github.ref, 'refs/tags/')
steps: steps:
- name: Trigger binaries sign pipelines - name: Trigger binaries sign pipelines

View File

@@ -9,6 +9,8 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }} group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true cancel-in-progress: true
# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
jobs: jobs:
trigger_sync_tag: trigger_sync_tag:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -20,4 +22,30 @@ jobs:
ref: main ref: main
repo: ${{ secrets.UPSTREAM_REPO }} repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }} token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_android_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/android-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_ios_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }' inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -58,6 +58,11 @@ linters:
govet: govet:
enable: enable:
- nilness - nilness
disable:
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
# directives but cannot perform the rewrite due to generic type
# parameter inference limitations in the Go inliner.
- inline
enable-all: false enable-all: false
revive: revive:
rules: rules:

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT): $(GOLANGCI_LINT):
@echo "Installing golangci-lint..." @echo "Installing golangci-lint..."
@mkdir -p ./bin @mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push) # Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT) lint: $(GOLANGCI_LINT)

View File

@@ -17,6 +17,7 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \ NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \ NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENABLE_CAPTURE="false" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,6 +23,7 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \ NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \ NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \ NB_DISABLE_DNS="true" \
NB_ENABLE_CAPTURE="false" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30" NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ] ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -8,6 +8,7 @@ import (
"os" "os"
"slices" "slices"
"sync" "sync"
"time"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -15,6 +16,7 @@ import (
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -26,6 +28,7 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile
@@ -68,7 +71,30 @@ type Client struct {
uiVersion string uiVersion string
networkChangeListener listener.NetworkChangeListener networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
} }
// NewClient instantiate a new Client // NewClient instantiate a new Client
@@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot). // RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath() cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath() stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context // todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx) ctx = internal.CtxInitState(ctx)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
} }
// Stop the internal client and free the resources // Stop the internal client and free the resources
@@ -173,11 +203,12 @@ func (c *Client) Stop() {
} }
func (c *Client) RenewTun(fd int) error { func (c *Client) RenewTun(fd int) error {
if c.connectClient == nil { cc := c.getConnectClient()
if cc == nil {
return fmt.Errorf("engine not running") return fmt.Errorf("engine not running")
} }
e := c.connectClient.Engine() e := cc.Engine()
if e == nil { if e == nil {
return fmt.Errorf("engine not initialized") return fmt.Errorf("engine not initialized")
} }
@@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd) return e.RenewTun(fd)
} }
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level // SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() { func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel) log.SetLevel(log.TraceLevel)
@@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray {
} }
func (c *Client) Networks() *NetworkArray { func (c *Client) Networks() *NetworkArray {
if c.connectClient == nil { cc := c.getConnectClient()
if cc == nil {
log.Error("not connected") log.Error("not connected")
return nil return nil
} }
engine := c.connectClient.Engine() engine := cc.Engine()
if engine == nil { if engine == nil {
log.Error("could not get engine") log.Error("could not get engine")
return nil return nil
@@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
} }
func (c *Client) getRouteManager() (routemanager.Manager, error) { func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.connectClient client := c.getConnectClient()
if client == nil { if client == nil {
return nil, fmt.Errorf("not connected") return nil, fmt.Errorf("not connected")
} }

View File

@@ -7,4 +7,5 @@ package android
type PlatformFiles interface { type PlatformFiles interface {
ConfigurationFilePath() string ConfigurationFilePath() string
StateFilePath() string StateFilePath() string
CacheDir() string
} }

196
client/cmd/capture.go Normal file
View File

@@ -0,0 +1,196 @@
package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"github.com/hashicorp/go-multierror"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util/capture"
)
var captureCmd = &cobra.Command{
Use: "capture",
Short: "Capture packets on the WireGuard interface",
Long: `Captures decrypted packets flowing through the WireGuard interface.
Default output is human-readable text. Use --pcap or --output for pcap binary.
Requires --enable-capture to be set at service install or reconfigure time.
Examples:
netbird debug capture
netbird debug capture host 100.64.0.1 and port 443
netbird debug capture tcp
netbird debug capture icmp
netbird debug capture src host 10.0.0.1 and dst port 80
netbird debug capture -o capture.pcap
netbird debug capture --pcap | tshark -r -
netbird debug capture --pcap | tcpdump -r - -n`,
Args: cobra.ArbitraryArgs,
RunE: runCapture,
}
func init() {
debugCmd.AddCommand(captureCmd)
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
}
func runCapture(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
cmd.PrintErrf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
req, err := buildCaptureRequest(cmd, args)
if err != nil {
return err
}
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
stream, err := client.StartCapture(ctx, req)
if err != nil {
return handleCaptureError(err)
}
// First Recv is the empty acceptance message from the server. If the
// device is unavailable (kernel WG, not connected, capture disabled),
// the server returns an error instead.
if _, err := stream.Recv(); err != nil {
return handleCaptureError(err)
}
out, cleanup, err := captureOutput(cmd)
if err != nil {
return err
}
if req.TextOutput {
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
} else {
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
}
streamErr := streamCapture(ctx, cmd, stream, out)
cleanupErr := cleanup()
if streamErr != nil {
return streamErr
}
return cleanupErr
}
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
req := &proto.StartCaptureRequest{}
if len(args) > 0 {
expr := strings.Join(args, " ")
if _, err := capture.ParseFilter(expr); err != nil {
return nil, fmt.Errorf("invalid filter: %w", err)
}
req.FilterExpr = expr
}
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
req.SnapLen = snap
}
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
if d < 0 {
return nil, fmt.Errorf("duration must not be negative")
}
req.Duration = durationpb.New(d)
}
req.Verbose, _ = cmd.Flags().GetBool("verbose")
req.Ascii, _ = cmd.Flags().GetBool("ascii")
outPath, _ := cmd.Flags().GetString("output")
forcePcap, _ := cmd.Flags().GetBool("pcap")
req.TextOutput = !forcePcap && outPath == ""
return req, nil
}
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
for {
pkt, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.PrintErrf("\nCapture stopped.\n")
return nil //nolint:nilerr // user interrupted
}
if err == io.EOF {
cmd.PrintErrf("\nCapture finished.\n")
return nil
}
return handleCaptureError(err)
}
if _, err := out.Write(pkt.GetData()); err != nil {
return fmt.Errorf("write output: %w", err)
}
}
}
// captureOutput returns the writer for capture data and a cleanup function
// that finalizes the file. Errors from the cleanup must be propagated.
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
outPath, _ := cmd.Flags().GetString("output")
if outPath == "" {
return os.Stdout, func() error { return nil }, nil
}
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
if err != nil {
return nil, nil, fmt.Errorf("create output file: %w", err)
}
tmpPath := f.Name()
return f, func() error {
var merr *multierror.Error
if err := f.Close(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
}
fi, statErr := os.Stat(tmpPath)
if statErr != nil || fi.Size() == 0 {
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
}
return nberrors.FormatErrorOrNil(merr)
}
if err := os.Rename(tmpPath, outPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
return nberrors.FormatErrorOrNil(merr)
}
cmd.PrintErrf("Wrote %s\n", outPath)
return nberrors.FormatErrorOrNil(merr)
}, nil
}
func handleCaptureError(err error) error {
if s, ok := status.FromError(err); ok {
return fmt.Errorf("%s", s.Message())
}
return err
}

View File

@@ -9,6 +9,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/debug"
@@ -239,11 +240,50 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}() }()
} }
captureStarted := false
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
captureTimeout := duration + 30*time.Second
const maxBundleCapture = 10 * time.Minute
if captureTimeout > maxBundleCapture {
captureTimeout = maxBundleCapture
}
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
Timeout: durationpb.New(captureTimeout),
})
if err != nil {
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
} else {
captureStarted = true
cmd.Println("Packet capture started.")
// Safety: always stop on exit, even if the normal stop below runs too.
defer func() {
if captureStarted {
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
}
}
}()
}
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil { if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr return waitErr
} }
cmd.Println("\nDuration completed") cmd.Println("\nDuration completed")
if captureStarted {
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
} else {
captureStarted = false
cmd.Println("Packet capture stopped.")
}
}
if cpuProfilingStarted { if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil { if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err) cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
@@ -416,4 +456,5 @@ func init() {
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle") forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server") forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle") forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
} }

View File

@@ -10,6 +10,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/term"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -23,6 +24,7 @@ import (
func init() { func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location") loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
} }
@@ -256,7 +258,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
} }
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error { func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser) openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName}) resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil { if err != nil {
@@ -324,7 +326,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err) return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
} }
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser) openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo) tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil { if err != nil {
@@ -334,7 +336,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return &tokenInfo, nil return &tokenInfo, nil
} }
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) { func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
var codeMsg string var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) { if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode) codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -348,6 +350,12 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
verificationURIComplete + " " + codeMsg) verificationURIComplete + " " + codeMsg)
} }
if showQR {
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
printQRCode(f, verificationURIComplete)
}
}
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {

25
client/cmd/qr.go Normal file
View File

@@ -0,0 +1,25 @@
package cmd
import (
"io"
"github.com/mdp/qrterminal/v3"
)
// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}

26
client/cmd/qr_test.go Normal file
View File

@@ -0,0 +1,26 @@
package cmd
import (
"bytes"
"testing"
)
func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "")
if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}
func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "https://example.com/auth")
if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}

View File

@@ -75,6 +75,7 @@ var (
mtu uint16 mtu uint16
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool networksDisabled bool
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{

View File

@@ -44,6 +44,7 @@ func init() {
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks") serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
} }
} }
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled) serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
if err := serverInstance.Start(); err != nil { if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err) log.Fatalf("failed to start daemon: %v", err)
} }

View File

@@ -59,6 +59,10 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings") args = append(args, "--disable-update-settings")
} }
if captureEnabled {
args = append(args, "--enable-capture")
}
if networksDisabled { if networksDisabled {
args = append(args, "--disable-networks") args = append(args, "--disable-networks")
} }

View File

@@ -28,6 +28,7 @@ type serviceParams struct {
LogFiles []string `json:"log_files,omitempty"` LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"` DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"` DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
} }
@@ -79,6 +80,7 @@ func currentServiceParams() *serviceParams {
LogFiles: logFiles, LogFiles: logFiles,
DisableProfiles: profilesDisabled, DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled, DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled, DisableNetworks: networksDisabled,
} }
@@ -144,6 +146,10 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
updateSettingsDisabled = params.DisableUpdateSettings updateSettingsDisabled = params.DisableUpdateSettings
} }
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
captureEnabled = params.EnableCapture
}
if !serviceCmd.PersistentFlags().Changed("disable-networks") { if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks networksDisabled = params.DisableNetworks
} }

View File

@@ -535,6 +535,7 @@ func fieldToGlobalVar(field string) string {
"LogFiles": "logFiles", "LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled", "DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled", "DisableUpdateSettings": "updateSettingsDisabled",
"EnableCapture": "captureEnabled",
"DisableNetworks": "networksDisabled", "DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars", "ServiceEnvVars": "serviceEnvVars",
} }

View File

@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -160,7 +160,7 @@ func startClientDaemon(
s := grpc.NewServer() s := grpc.NewServer()
server := client.New(ctx, server := client.New(ctx,
"", "", false, false, false) "", "", false, false, false, false)
if err := server.Start(); err != nil { if err := server.Start(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -39,6 +39,9 @@ const (
noBrowserFlag = "no-browser" noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login" noBrowserDesc = "do not open the browser for SSO login"
showQRFlag = "qr"
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
profileNameFlag = "profile" profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used." profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
) )
@@ -48,6 +51,7 @@ var (
dnsLabels []string dnsLabels []string
dnsLabelsValidated domain.List dnsLabelsValidated domain.List
noBrowser bool noBrowser bool
showQR bool
profileName string profileName string
configPath string configPath string
@@ -80,6 +84,7 @@ func init() {
) )
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc) upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc) upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ") upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")

65
client/embed/capture.go Normal file
View File

@@ -0,0 +1,65 @@
package embed
import (
"io"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/capture"
)
// CaptureOptions configures a packet capture session.
type CaptureOptions struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
// Empty captures all packets.
Filter string
// Verbose adds seq/ack, TTL, window, and total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
}
// CaptureStats reports capture session counters.
type CaptureStats struct {
Packets int64
Bytes int64
Dropped int64
}
// CaptureSession represents an active packet capture. Call Stop to end the
// capture and flush buffered packets.
type CaptureSession struct {
sess *capture.Session
engine *internal.Engine
}
// Stop ends the capture, flushes remaining packets, and detaches from the device.
// Safe to call multiple times.
func (cs *CaptureSession) Stop() {
if cs.engine != nil {
_ = cs.engine.SetCapture(nil)
cs.engine = nil
}
if cs.sess != nil {
cs.sess.Stop()
}
}
// Stats returns current capture counters.
func (cs *CaptureSession) Stats() CaptureStats {
s := cs.sess.Stats()
return CaptureStats{
Packets: s.Packets,
Bytes: s.Bytes,
Dropped: s.Dropped,
}
}
// Done returns a channel that is closed when the capture's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (cs *CaptureSession) Done() <-chan struct{} {
return cs.sess.Done()
}

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/capture"
) )
var ( var (
@@ -65,7 +66,7 @@ type Options struct {
PrivateKey string PrivateKey string
// ManagementURL overrides the default management server URL // ManagementURL overrides the default management server URL
ManagementURL string ManagementURL string
// PreSharedKey is the pre-shared key for the WireGuard interface // PreSharedKey is the pre-shared key for the tunnel interface
PreSharedKey string PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil) // LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer LogOutput io.Writer
@@ -81,9 +82,9 @@ type Options struct {
DisableClientRoutes bool DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers // BlockInbound blocks all inbound connections from peers
BlockInbound bool BlockInbound bool
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port. // WireguardPort is the port for the tunnel interface. Use 0 for a random port.
WireguardPort *int WireguardPort *int
// MTU is the MTU for the WireGuard interface. // MTU is the MTU for the tunnel interface.
// Valid values are in the range 576..8192 bytes. // Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file. // If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280. // If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
@@ -469,6 +470,52 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress) return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
} }
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
var matcher capture.Matcher
if opts.Filter != "" {
m, err := capture.ParseFilter(opts.Filter)
if err != nil {
return nil, fmt.Errorf("parse filter: %w", err)
}
matcher = m
}
sess, err := capture.NewSession(capture.Options{
Output: opts.Output,
TextOutput: opts.TextOutput,
Matcher: matcher,
Verbose: opts.Verbose,
ASCII: opts.ASCII,
})
if err != nil {
return nil, fmt.Errorf("create capture session: %w", err)
}
if err := engine.SetCapture(sess); err != nil {
sess.Stop()
return nil, fmt.Errorf("set capture: %w", err)
}
return &CaptureSession{sess: sess, engine: engine}, nil
}
// StopCapture stops the active capture session if one is running.
func (c *Client) StopCapture() error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetCapture(nil)
}
// getEngine safely retrieves the engine from the client with proper locking. // getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started. // Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available. // Returns ErrEngineNotStarted if the engine is not available.

View File

@@ -0,0 +1,11 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"

View File

@@ -0,0 +1,260 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}

View File

@@ -0,0 +1,49 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,25 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err) log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
} }
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
// persist early to ensure cleanup of chains // persist early to ensure cleanup of chains
go func() { go func() {
if err := stateManager.PersistState(context.Background()); err != nil { if err := stateManager.PersistState(context.Background()); err != nil {
@@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
} }
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded // attempt to delete state only if all other operations succeeded
if merr == nil { if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error {
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -14,6 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
@@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error {
return fmt.Errorf("flush allow input netbird rules: %w", err) return fmt.Errorf("flush allow input netbird rules: %w", err)
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -19,6 +19,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
@@ -40,6 +41,8 @@ const (
chainNameForward = "FORWARD" chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward" chainNameMangleForward = "netbird-mangle-forward"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept" userDataAcceptInputRule = "inputaccept"
@@ -133,6 +136,10 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
} }
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil { if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
} }
@@ -280,6 +287,10 @@ func (r *router) createContainers() error {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err) log.Errorf("failed to refresh rules: %s", err)
} }
@@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false return false
} }
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip all iptables-managed tables in the ip family // Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false return false

View File

@@ -3,6 +3,9 @@
package uspfilter package uspfilter
import ( import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager) return m.nativeFirewall.Close(stateManager)
} }
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil return nil
} }
@@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil { if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird() return m.nativeFirewall.AllowNetbird()
} }
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil return nil
} }

View File

@@ -9,6 +9,7 @@ import (
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() wgaddr.Address Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device

View File

@@ -115,12 +115,13 @@ type Manager struct {
localipmanager *localIPManager localipmanager *localIPManager
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder atomic.Pointer[forwarder.Forwarder] forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger pendingCapture atomic.Pointer[forwarder.PacketCapture]
flowLogger nftypes.FlowLogger logger *nblog.Logger
flowLogger nftypes.FlowLogger
blockRule firewall.Rule blockRule firewall.Rule
@@ -351,6 +352,19 @@ func (m *Manager) determineRouting() error {
return nil return nil
} }
// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
if pc == nil {
m.pendingCapture.Store(nil)
} else {
m.pendingCapture.Store(&pc)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.SetCapture(pc)
}
}
// initForwarder initializes the forwarder, it disables routing on errors // initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error { func (m *Manager) initForwarder() error {
if m.forwarder.Load() != nil { if m.forwarder.Load() != nil {
@@ -372,6 +386,11 @@ func (m *Manager) initForwarder() error {
m.forwarder.Store(forwarder) m.forwarder.Store(forwarder)
// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
if pc := m.pendingCapture.Load(); pc != nil {
forwarder.SetCapture(*pc)
}
log.Debug("forwarder initialized") log.Debug("forwarder initialized")
return nil return nil
@@ -614,6 +633,7 @@ func (m *Manager) resetState() {
} }
if fwder := m.forwarder.Load(); fwder != nil { if fwder := m.forwarder.Load(); fwder != nil {
fwder.SetCapture(nil)
fwder.Stop() fwder.Stop()
} }

View File

@@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
NameFunc func() string
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
func (i *IFaceMock) Name() string {
if i.NameFunc == nil {
return "wgtest"
}
return i.NameFunc()
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device { func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil { if i.GetWGDeviceFunc == nil {
return nil return nil

View File

@@ -12,12 +12,19 @@ import (
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
) )
// PacketCapture captures raw packets for debugging. Implementations must be
// safe for concurrent use and must not block.
type PacketCapture interface {
Offer(data []byte, outbound bool)
}
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct { type endpoint struct {
logger *nblog.Logger logger *nblog.Logger
dispatcher stack.NetworkDispatcher dispatcher stack.NetworkDispatcher
device *wgdevice.Device device *wgdevice.Device
mtu atomic.Uint32 mtu atomic.Uint32
capture atomic.Pointer[PacketCapture]
} }
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -54,13 +61,17 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
continue continue
} }
// Send the packet through WireGuard pktBytes := data.AsSlice()
address := netHeader.DestinationAddress() address := netHeader.DestinationAddress()
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
if err != nil {
e.logger.Error1("CreateOutboundPacket: %v", err) e.logger.Error1("CreateOutboundPacket: %v", err)
continue continue
} }
if pc := e.capture.Load(); pc != nil {
(*pc).Offer(pktBytes, true)
}
written++ written++
} }

View File

@@ -139,6 +139,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return f, nil return f, nil
} }
// SetCapture sets or clears the packet capture on the forwarder endpoint.
// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
func (f *Forwarder) SetCapture(pc PacketCapture) {
if pc == nil {
f.endpoint.capture.Store(nil)
return
}
f.endpoint.capture.Store(&pc)
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error { func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize { if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload)) return fmt.Errorf("packet too small: %d bytes", len(payload))

View File

@@ -270,5 +270,9 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
return 0 return 0
} }
if pc := f.endpoint.capture.Load(); pc != nil {
(*pc).Offer(fullPacket, true)
}
return len(fullPacket) return len(fullPacket)
} }

View File

@@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++ ipv6Count++
} }
assert.Equal(t, packetsPerFamily, ipv4Count) // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
assert.Equal(t, packetsPerFamily, ipv6Count) // routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
} }
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package device
import ( import (
"net/netip" "net/netip"
"sync" "sync"
"sync/atomic"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@@ -28,11 +29,20 @@ type PacketFilter interface {
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
} }
// PacketCapture captures raw packets for debugging. Implementations must be
// safe for concurrent use and must not block.
type PacketCapture interface {
// Offer submits a packet for capture. outbound is true for packets
// leaving the host (Read path), false for packets arriving (Write path).
Offer(data []byte, outbound bool)
}
// FilteredDevice to override Read or Write of packets // FilteredDevice to override Read or Write of packets
type FilteredDevice struct { type FilteredDevice struct {
tun.Device tun.Device
filter PacketFilter filter PacketFilter
capture atomic.Pointer[PacketCapture]
mutex sync.RWMutex mutex sync.RWMutex
closeOnce sync.Once closeOnce sync.Once
} }
@@ -63,20 +73,25 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
if n, err = d.Device.Read(bufs, sizes, offset); err != nil { if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err return 0, err
} }
d.mutex.RLock() d.mutex.RLock()
filter := d.filter filter := d.filter
d.mutex.RUnlock() d.mutex.RUnlock()
if filter == nil { if filter != nil {
return for i := 0; i < n; i++ {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
}
}
} }
for i := 0; i < n; i++ { if pc := d.capture.Load(); pc != nil {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) { for i := 0; i < n; i++ {
bufs = append(bufs[:i], bufs[i+1:]...) (*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
} }
} }
@@ -85,6 +100,13 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
// Write wraps write method with filtering feature // Write wraps write method with filtering feature
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
// Capture before filtering so dropped packets are still visible in captures.
if pc := d.capture.Load(); pc != nil {
for _, buf := range bufs {
(*pc).Offer(buf[offset:], false)
}
}
d.mutex.RLock() d.mutex.RLock()
filter := d.filter filter := d.filter
d.mutex.RUnlock() d.mutex.RUnlock()
@@ -96,9 +118,10 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.FilterInbound(buf[offset:], len(buf)) { if filter.FilterInbound(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} else {
filteredBufs = append(filteredBufs, buf)
} }
} }
@@ -113,3 +136,14 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.filter = filter d.filter = filter
d.mutex.Unlock() d.mutex.Unlock()
} }
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.
func (d *FilteredDevice) SetCapture(pc PacketCapture) {
if pc == nil {
d.capture.Store(nil)
return
}
d.capture.Store(&pc)
}

View File

@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
return return
} }
if n != 0 { if n != 1 {
t.Errorf("expected n=1, got %d", n) t.Errorf("expected n=1, got %d", n)
return return
} }

View File

@@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface // Close closes the tunnel interface
func (w *WGIface) Close() error { func (w *WGIface) Close() error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error var result *multierror.Error
@@ -225,7 +224,15 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
} }
if err := w.tun.Close(); err != nil { // Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
} }

View File

@@ -0,0 +1,113 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
} }
if u.address.Network.Contains(a) { if u.address.Network.Contains(a) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
} }
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted) u.addrCache.Store(addr.String(), isRouted)
if isRouted { if isRouted {
// Extra log, as the error only shows up with ICE logging enabled // Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
} }
} }

View File

@@ -201,7 +201,18 @@ Pop $0
Function .onInit Function .onInit
StrCpy $INSTDIR "${INSTALL_DIR}" StrCpy $INSTDIR "${INSTALL_DIR}"
; Default autostart to enabled so silent installs (/S) match the interactive default
StrCpy $AutostartEnabled "1"
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
; in the 32-bit view. Fall back to it so upgrades still find them.
SetRegView 64
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString" ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
${If} $R0 == ""
SetRegView 32
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
SetRegView 64
${EndIf}
${If} $R0 != "" ${If} $R0 != ""
# if silent install jump to uninstall step # if silent install jump to uninstall step
IfSilent uninstall IfSilent uninstall
@@ -214,6 +225,10 @@ ${If} $R0 != ""
${EndIf} ${EndIf}
FunctionEnd FunctionEnd
Function un.onInit
SetRegView 64
FunctionEnd
###################################################################### ######################################################################
Section -MainProgram Section -MainProgram
${INSTALL_TYPE} ${INSTALL_TYPE}
@@ -228,6 +243,7 @@ Section -MainProgram
!else !else
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
!endif !endif
File "..\\client\\ui\\assets\\netbird.png"
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -247,9 +263,11 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox ; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled" DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1" ${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe" WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe" DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else} ${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user" DetailPrint "Autostart not enabled by user"
${EndIf} ${EndIf}
@@ -283,6 +301,8 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry ; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..." DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox ; Handle data deletion based on checkbox
@@ -321,6 +341,7 @@ DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
DetailPrint "Removing application directory from PATH..." DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM

View File

@@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort, dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener, dnsReadyListener dns.ReadyListener,
stateFilePath string, stateFilePath string,
cacheDir string,
) error { ) error {
// in case of non Android os these variables will be nil // in case of non Android os these variables will be nil
mobileDependency := MobileDependency{ mobileDependency := MobileDependency{
@@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses, HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener, DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath, StateFilePath: stateFilePath,
TempDir: cacheDir,
} }
return c.run(mobileDependency, nil, "") return c.run(mobileDependency, nil, "")
} }
@@ -331,6 +333,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
@@ -338,6 +344,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err) log.Error(err)
return wrapErr(err) return wrapErr(err)
} }
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager) c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -16,7 +16,6 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"runtime/pprof" "runtime/pprof"
"slices"
"sort" "sort"
"strings" "strings"
"time" "time"
@@ -31,7 +30,6 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer" "github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
) )
const readmeContent = `Netbird debug bundle const readmeContent = `Netbird debug bundle
@@ -63,6 +61,7 @@ allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information. threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information. cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
Anonymization Process Anonymization Process
@@ -234,7 +233,9 @@ type BundleGenerator struct {
statusRecorder *peer.Status statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse syncResponse *mgmProto.SyncResponse
logPath string logPath string
tempDir string
cpuProfile []byte cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter clientMetrics MetricsExporter
@@ -256,8 +257,10 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse SyncResponse *mgmProto.SyncResponse
LogPath string LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation CapturePath string
RefreshStatus func()
ClientMetrics MetricsExporter ClientMetrics MetricsExporter
} }
@@ -275,7 +278,9 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder, statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse, syncResponse: deps.SyncResponse,
logPath: deps.LogPath, logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile, cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus, refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics, clientMetrics: deps.ClientMetrics,
@@ -287,7 +292,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location. // Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) { func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
if err != nil { if err != nil {
return "", fmt.Errorf("create zip file: %w", err) return "", fmt.Errorf("create zip file: %w", err)
} }
@@ -345,6 +350,10 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add CPU profile to debug bundle: %v", err) log.Errorf("failed to add CPU profile to debug bundle: %v", err)
} }
if err := g.addCaptureFile(); err != nil {
log.Errorf("failed to add capture file to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil { if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err) log.Errorf("failed to add stack trace to debug bundle: %v", err)
} }
@@ -373,15 +382,8 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err) log.Errorf("failed to add wg show output: %v", err)
} }
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { if err := g.addPlatformLog(); err != nil {
if err := g.addLogfile(); err != nil { log.Errorf("failed to add logs to debug bundle: %v", err)
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
} }
if err := g.addUpdateLogs(); err != nil { if err := g.addUpdateLogs(); err != nil {
@@ -605,6 +607,12 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil { if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
} }
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -631,6 +639,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
} }
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {
@@ -675,6 +684,29 @@ func (g *BundleGenerator) addCPUProfile() error {
return nil return nil
} }
func (g *BundleGenerator) addCaptureFile() error {
if g.capturePath == "" {
return nil
}
if g.anonymize {
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
return nil
}
f, err := os.Open(g.capturePath)
if err != nil {
return fmt.Errorf("open capture file: %w", err)
}
defer f.Close()
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
return fmt.Errorf("add capture file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error { func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true) n := runtime.Stack(buf, true)

View File

@@ -0,0 +1,41 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -0,0 +1,25 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -5,16 +5,21 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net" "net"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -471,8 +476,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"HOME": "/root", "HOME": "/root",
"PATH": "/usr/bin", "PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug", "NB_LOG_LEVEL": "debug",
}, },
}, },
@@ -489,9 +494,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123", "NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz", "NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info", "NB_LOG_LEVEL": "info",
}, },
}, },
check: func(t *testing.T, params map[string]any) { check: func(t *testing.T, params map[string]any) {
@@ -766,3 +771,127 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
} }
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -3,10 +3,12 @@ package debug
import ( import (
"context" "context"
"errors" "errors"
"net"
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
"time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@@ -19,8 +21,10 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci") t.Skip("Skipping upload test on docker ci")
} }
testDir := t.TempDir() testDir := t.TempDir()
testURL := "http://localhost:8080" addr := reserveLoopbackPort(t)
testURL := "http://" + addr
t.Setenv("SERVER_URL", testURL) t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir) t.Setenv("STORE_DIR", testDir)
srv := server.NewServer() srv := server.NewServer()
go func() { go func() {
@@ -33,6 +37,7 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err) t.Errorf("Failed to stop server: %v", err)
} }
}) })
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile") file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content") fileContent := []byte("test file content")
@@ -47,3 +52,30 @@ func TestUpload(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent) require.Equal(t, fileContent, createdFileContent)
} }
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -13,6 +13,7 @@ import (
const ( const (
defaultResolvConfPath = "/etc/resolv.conf" defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
) )
type resolvConf struct { type resolvConf struct {

View File

@@ -1,7 +1,10 @@
package dns package dns
import ( import (
"context"
"fmt" "fmt"
"math"
"net"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() {
} }
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
@@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) { if !c.isHandlerMatch(qname, entry) {
continue continue
} }
@@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime)) cw.response.Len(), meta, time.Since(startTime))
} }
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch { switch {
case entry.Pattern == ".": case entry.Pattern == ".":
@@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
} }
} }
} }
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,11 +1,15 @@
package dns_test package dns_test
import ( import (
"context"
"net"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
}) })
} }
} }
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, err := getOSDNSManagerType() osManager, reason, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err) return nil, fmt.Errorf("get os dns manager type: %w", err)
} }
log.Infof("System DNS manager discovered: %s", osManager) log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
mgr, err := newHostManagerFromType(wgInterface, osManager) mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil { if err != nil {
@@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
} }
} }
func getOSDNSManagerType() (osManagerType, error) { func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
file, err := os.Open(defaultResolvConfPath) file, err := os.Open(defaultResolvConfPath)
if err != nil { if err != nil {
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
} }
defer func() { defer func() {
if err := file.Close(); err != nil { if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err) log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
} }
}() }()
var rejected []string
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
@@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) {
continue continue
} }
if text[0] != '#' { if text[0] != '#' {
return fileManager, nil break
} }
if strings.Contains(text, fileGeneratedResolvConfContentHeader) { if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return netbirdManager, nil return mgr, reason, nil, nil
} } else if rej != "" {
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { rejected = append(rejected, rej)
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
} }
} }
if err := scanner.Err(); err != nil && err != io.EOF { if err := scanner.Err(); err != nil && err != io.EOF {
return 0, fmt.Errorf("scan: %w", err) return 0, "", nil, fmt.Errorf("scan: %w", err)
} }
return 0, "", rejected, nil
return fileManager, nil
} }
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager. // matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
func checkStub() bool { func checkStub() bool {
rConf, err := parseDefaultResolvConf() rConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Warnf("failed to parse resolv conf: %s", err) log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
return true return true
} }
@@ -139,3 +178,36 @@ func checkStub() bool {
return false return false
} }
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -0,0 +1,76 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,40 +2,83 @@ package mgmt
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"net/url" "net/url"
"os"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const dnsTimeout = 5 * time.Second const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
// Resolver caches critical NetBird infrastructure domains // envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct { type Resolver struct {
records map[dns.Question][]dns.RR records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex mutex sync.RWMutex
}
type ipsResponse struct { chain ChainResolver
ips []netip.Addr chainMaxPriority int
err error refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
} }
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question][]dns.RR), records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
} }
} }
@@ -44,7 +87,19 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver" return "MgmtCacheResolver"
} }
// ServeDNS implements dns.Handler interface. // SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
m.continueToNext(w, r) m.continueToNext(w, r)
@@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
m.mutex.RLock() m.mutex.RLock()
records, found := m.records[question] cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
m.mutex.RUnlock() m.mutex.RUnlock()
if !found { if !found {
@@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{} resp := &dns.Msg{}
resp.SetReply(r) resp.SetReply(r)
resp.Authoritative = false resp.Authoritative = false
resp.RecursionAvailable = true resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// AddDomain manually adds a domain to cache by resolving it. // AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout) ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel() defer cancel()
ips, err := lookupIPWithExtraTimeout(ctx, d) aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if err != nil {
return err if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
} }
var aRecords, aaaaRecords []dns.RR if len(aRecords) == 0 && len(aaaaRecords) == 0 {
for _, ip := range ips { if err := errors.Join(errA, errAAAA); err != nil {
if ip.Is4() { return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
} }
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
} }
now := time.Now()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
if len(aRecords) > 0 { m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
aQuestion := dns.Question{ m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
if len(aaaaRecords) > 0 { log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords)) d.SafeString(), len(aRecords), len(aaaaRecords))
return nil return nil
} }
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { // applyFamilyRecords writes records, evicts on NODATA, leaves the cache
log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) // untouched on error. Caller holds m.mutex.
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
resultChan := make(chan *ipsResponse, 1) q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
go func() { // scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) // unique in-flight key; bursty stale hits share its channel. expected is the
resultChan <- &ipsResponse{ // cachedRecord pointer observed by the caller; the refresh only mutates the
err: err, // cache if that pointer is still the one stored, so a stale in-flight refresh
ips: ips, // can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
} }
}() logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
var resp *ipsResponse return err
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
} }
if resp.err != nil { // NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
} }
return resp.ips, nil
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
} }
// PopulateFromConfig extracts and caches domains from the client configuration. // PopulateFromConfig extracts and caches domains from the client configuration.
@@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
aQuestion := dns.Question{ qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
Name: dnsName, qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
Qtype: dns.TypeA, delete(m.records, qA)
Qclass: dns.ClassINET, delete(m.records, qAAAA)
} delete(m.refreshing, qA)
delete(m.records, aQuestion) delete(m.refreshing, qAAAA)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString()) log.Debugf("removed domain=%s from cache", d.SafeString())
return nil return nil
@@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains return domains
} }
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -0,0 +1,408 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,6 +6,7 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains()) assert.False(t, resolver.MatchSubdomains())
} }
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) { func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct { tests := []struct {
name string name string

View File

@@ -212,6 +212,7 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver() mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,

View File

@@ -26,7 +26,9 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
@@ -67,6 +69,7 @@ import (
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto" sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -140,6 +143,7 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config ProfileConfig *profilemanager.Config
LogPath string LogPath string
TempDir string
} }
// EngineServices holds the external service dependencies required by the Engine. // EngineServices holds the external service dependencies required by the Engine.
@@ -216,6 +220,8 @@ type Engine struct {
portForwardManager *portforward.Manager portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
afpacketCapture *capture.AFPacketCapture
// Sync response persistence (protected by syncRespMux) // Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex syncRespMux sync.RWMutex
persistSyncResponse bool persistSyncResponse bool
@@ -569,7 +575,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx) e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start() e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents() e.receiveSignalEvents()
e.receiveManagementEvents() e.receiveManagementEvents()
@@ -603,6 +609,8 @@ func (e *Engine) createFirewall() error {
return nil return nil
} }
firewalld.SetParentContext(e.ctx)
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil { if err != nil {
@@ -940,7 +948,12 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err) return fmt.Errorf("update relay token: %w", err)
} }
e.relayManager.UpdateServerURLs(update.Urls) urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled. // Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries. // We can ignore all errors because the guard will manage the reconnection retries.
@@ -1095,6 +1108,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder, StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse, SyncResponse: syncResponse,
LogPath: e.config.LogPath, LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics, ClientMetrics: e.clientMetrics,
RefreshStatus: func() { RefreshStatus: func() {
e.RunHealthProbes(true) e.RunHealthProbes(true)
@@ -1693,6 +1707,11 @@ func (e *Engine) parseNATExternalIPMappings() []string {
} }
func (e *Engine) close() { func (e *Engine) close() {
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
@@ -2158,6 +2177,62 @@ func (e *Engine) Address() (netip.Addr, error) {
return e.wgInterface.Address().IP, nil return e.wgInterface.Address().IP, nil
} }
// SetCapture sets or clears packet capture on the WireGuard device.
// On userspace WireGuard, it taps the FilteredDevice directly.
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
// Pass nil to disable capture.
func (e *Engine) SetCapture(pc device.PacketCapture) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
intf := e.wgInterface
if intf == nil {
return errors.New("wireguard interface not initialized")
}
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
dev := intf.GetDevice()
if dev != nil {
dev.SetCapture(pc)
e.setForwarderCapture(pc)
return nil
}
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
if pc == nil {
return nil
}
sess, ok := pc.(*capture.Session)
if !ok {
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
}
afc := capture.NewAFPacketCapture(intf.Name(), sess)
if err := afc.Start(); err != nil {
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
}
e.afpacketCapture = afc
return nil
}
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
if e.firewall == nil {
return
}
type forwarderCapturer interface {
SetPacketCapture(pc forwarder.PacketCapture)
}
if fc, ok := e.firewall.(forwarderCapturer); ok {
fc.SetPacketCapture(pc)
}
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil { if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules") log.Warn("firewall is disabled, not updating forwarding rules")
@@ -2379,6 +2454,8 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
} }
} }
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
offerAnswer := peer.OfferAnswer{ offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{ IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag, UFrag: remoteCred.UFrag,
@@ -2389,7 +2466,23 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
RosenpassPubKey: rosenpassPubKey, RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr, RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
RelaySrvIP: relayIP,
SessionID: sessionID, SessionID: sessionID,
} }
return &offerAnswer, nil return &offerAnswer, nil
} }
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
// netip.Addr. Returns the zero value for empty input and logs a warning
// for malformed payloads.
func decodeRelayIP(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
ip, ok := netip.AddrFromSlice(b)
if !ok {
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
return netip.Addr{}
}
return ip.Unmap()
}

View File

@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -3,7 +3,6 @@ package activity
import ( import (
"net" "net"
"net/netip" "net/netip"
"runtime"
"testing" "testing"
"time" "time"
@@ -18,10 +17,6 @@ import (
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
) )
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing // mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct { type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn endpoints map[netip.Addr]net.Conn
@@ -181,10 +176,6 @@ func TestBindListener_Close(t *testing.T) {
} }
func TestManager_BindMode(t *testing.T) { func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager() mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -226,10 +217,6 @@ func TestManager_BindMode(t *testing.T) {
} }
func TestManager_BindMode_MultiplePeers(t *testing.T) { func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager() mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}

View File

@@ -4,14 +4,12 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -75,16 +73,6 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg) return NewUDPListener(m.wgIface, peerCfg)
} }
// BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider) provider, ok := m.wgIface.(bindProvider)
if !ok { if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")

View File

@@ -6,7 +6,6 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity" "github.com/netbirdio/netbird/client/internal/lazyconn/activity"
@@ -91,8 +90,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
m.routesMu.Lock() m.routesMu.Lock()
defer m.routesMu.Unlock() defer m.routesMu.Unlock()
maps.Clear(m.peerToHAGroups) clear(m.peerToHAGroups)
maps.Clear(m.haGroupToPeers) clear(m.haGroupToPeers)
for haUniqueID, routes := range haMap { for haUniqueID, routes := range haMap {
var peers []string var peers []string

View File

@@ -22,4 +22,8 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager DnsManager dns.IosDnsManager
FileDescriptor int32 FileDescriptor int32
StateFilePath string StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
} }

View File

@@ -3,8 +3,6 @@ package store
import ( import (
"sync" "sync"
"golang.org/x/exp/maps"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -30,7 +28,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() { func (m *Memory) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
maps.Clear(m.events) clear(m.events)
} }
func (m *Memory) GetEvents() []*types.Event { func (m *Memory) GetEvents() []*types.Event {

View File

@@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() forceRelay := IsForceRelayed()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) if !forceRelay {
if err != nil { relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
return err workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
} }
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !isForceRelayed() { if !forceRelay {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
} }
@@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel() conn.wgWatcherCancel()
} }
conn.workerRelay.CloseConn() conn.workerRelay.CloseConn()
conn.workerICE.Close() if conn.workerICE != nil {
conn.workerICE.Close()
}
if conn.wgProxyRelay != nil { if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn() err := conn.wgProxyRelay.CloseConn()
@@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate() conn.dumpState.RemoteCandidate()
conn.workerICE.OnRemoteCandidate(candidate, haRoutes) if conn.workerICE != nil {
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
} }
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
@@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusConnecting return StatusConnecting
} }
func (conn *Conn) isConnectedOnAllWay() (connected bool) { // isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
// would be better to protect this with a mutex, but it could cause deadlock with Close function //
// The result is a tri-state:
// - ConnStatusConnected: all available transports are up
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
// - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
defer func() { defer func() {
if !connected { if status == guard.ConnStatusDisconnected {
conn.logTraceConnState() conn.logTraceConnState()
} }
}() }()
// For JS platform: only relay connection is supported iceWorkerCreated := conn.workerICE != nil
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected var iceInProgress bool
if iceWorkerCreated {
iceInProgress = conn.workerICE.InProgress()
} }
// For non-JS platforms: check ICE connection status return evalConnStatus(connStatusInputs{
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { forceRelay: IsForceRelayed(),
return false peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
} relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
// If relay is supported with peer, it must also be connected iceWorkerCreated: iceWorkerCreated,
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
if conn.statusRelay.Get() == worker.StatusDisconnected { iceInProgress: iceInProgress,
return false })
}
}
return true
} }
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
@@ -926,3 +935,43 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil return remoteRosenpassPubKey != nil
} }
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
// "Relay up and needed" — the peer uses relay and the transport is connected.
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
if in.forceRelay {
return boolToConnStatus(relayUsedAndUp)
}
// Remote peer doesn't support ICE, or we haven't created the worker yet:
// relay is the only possible transport.
if !in.remoteSupportsICE || !in.iceWorkerCreated {
return boolToConnStatus(relayUsedAndUp)
}
// ICE counts as "up" when the status is anything other than Disconnected, OR
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
iceUp := in.iceStatusConnecting || in.iceInProgress
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
relayOK := !in.peerUsesRelay || in.relayConnected
switch {
case iceUp && relayOK:
return guard.ConnStatusConnected
case relayUsedAndUp:
// Relay is up but ICE is down — partially connected.
return guard.ConnStatusPartiallyConnected
default:
return guard.ConnStatusDisconnected
}
}
func boolToConnStatus(connected bool) guard.ConnStatus {
if connected {
return guard.ConnStatusConnected
}
return guard.ConnStatusDisconnected
}

View File

@@ -13,6 +13,20 @@ const (
StatusConnected StatusConnected
) )
// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
forceRelay bool // NB_FORCE_RELAY or JS/WASM
peerUsesRelay bool // remote peer advertises relay support AND local has relay
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
remoteSupportsICE bool // remote peer sent ICE credentials
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
iceStatusConnecting bool // statusICE is anything other than Disconnected
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describe the status of a peer's connection // ConnStatus describe the status of a peer's connection
type ConnStatus int32 type ConnStatus int32

View File

@@ -0,0 +1,201 @@
package peer
import (
"testing"
"github.com/netbirdio/netbird/client/internal/peer/guard"
)
func TestEvalConnStatus_ForceRelay(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "force relay, peer uses relay, relay up",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: true,
},
want: guard.ConnStatusConnected,
},
{
name: "force relay, peer uses relay, relay down",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: false,
},
want: guard.ConnStatusDisconnected,
},
{
name: "force relay, peer does NOT use relay - disconnected forever",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: false,
relayConnected: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "remote does not support ICE, peer uses relay, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer uses relay, relay down",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE worker not yet created, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: true,
iceWorkerCreated: false,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer does not use relay",
in: connStatusInputs{
peerUsesRelay: false,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
base := connStatusInputs{
remoteSupportsICE: true,
iceWorkerCreated: true,
}
tests := []struct {
name string
mutator func(*connStatusInputs)
want guard.ConnStatus
}{
{
name: "ICE connected, relay connected, peer uses relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE connected, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE InProgress only, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.iceStatusConnecting = false
in.iceInProgress = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE down, relay up, peer uses relay -> partial",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusPartiallyConnected,
},
{
name: "ICE down, peer does NOT use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = false
in.iceStatusConnecting = true
},
// relayOK = false (peer uses relay but it's down), iceUp = true
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
// falls into default: Disconnected.
want: guard.ConnStatusDisconnected,
},
{
name: "ICE down, relay up but peer does not use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = true // not actually used since peer doesn't rely on it
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
in := base
tc.mutator(&in)
if got := evalConnStatus(in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
}
})
}
}

View File

@@ -7,12 +7,38 @@ import (
) )
const ( const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY" EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
) )
func isForceRelayed() bool { func IsForceRelayed() bool {
if runtime.GOOS == "js" { if runtime.GOOS == "js" {
return true return true
} }
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
} }
// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
raw := os.Getenv(EnvKeyNBHomeRelayServers)
if raw == "" {
return nil, false
}
parts := strings.Split(raw, ",")
urls := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
urls = append(urls, p)
}
}
if len(urls) == 0 {
return nil, false
}
return urls, true
}

View File

@@ -8,7 +8,19 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type isConnectedFunc func() bool // ConnStatus represents the connection state as seen by the guard.
type ConnStatus int
const (
// ConnStatusDisconnected means neither ICE nor Relay is connected.
ConnStatusDisconnected ConnStatus = iota
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
ConnStatusPartiallyConnected
// ConnStatusConnected means all required connections are established.
ConnStatusConnected
)
type connStatusFunc func() ConnStatus
// Guard is responsible for the reconnection logic. // Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues. // It will trigger to send an offer to the peer then has connection issues.
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool
// - ICE candidate changes // - ICE candidate changes
type Guard struct { type Guard struct {
log *log.Entry log *log.Entry
isConnectedOnAllWay isConnectedFunc isConnectedOnAllWay connStatusFunc
timeout time.Duration timeout time.Duration
srWatcher *SRWatcher srWatcher *SRWatcher
relayedConnDisconnected chan struct{} relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{} iCEConnDisconnected chan struct{}
} }
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{ return &Guard{
log: log, log: log,
isConnectedOnAllWay: isConnectedFn, isConnectedOnAllWay: isConnectedFn,
@@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() {
} }
} }
// reconnectLoopWithRetry periodically check the connection status. // reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported //
// Behavior depends on the connection state reported by isConnectedOnAllWay:
// - Connected: no action, the peer is fully reachable.
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener() srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan) defer g.srWatcher.RemoveListener(srReconnectedChan)
@@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
tickerChannel := ticker.C tickerChannel := ticker.C
iceState := &iceRetryState{log: g.log}
defer iceState.reset()
for { for {
select { select {
case t := <-tickerChannel: case <-tickerChannel:
if t.IsZero() { switch g.isConnectedOnAllWay() {
g.log.Infof("retry timed out, stop periodic offer sending") case ConnStatusConnected:
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop // all good, nothing to do
tickerChannel = make(<-chan time.Time) case ConnStatusDisconnected:
continue callback()
case ConnStatusPartiallyConnected:
if iceState.shouldRetry() {
callback()
} else {
iceState.enterHourlyMode()
ticker.Stop()
tickerChannel = iceState.hourlyC()
}
} }
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected: case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker") g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-g.iCEConnDisconnected: case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker") g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-srReconnectedChan: case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker") g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop() ticker.Stop()
ticker = g.prepareExponentTicker(ctx) ticker = g.newReconnectTicker(ctx)
tickerChannel = ticker.C tickerChannel = ticker.C
iceState.reset()
case <-ctx.Done(): case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop") g.log.Debugf("context is done, stop reconnect loop")
@@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo) return backoff.NewTicker(bo)
} }
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{ bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1, RandomizationFactor: 0.1,

View File

@@ -0,0 +1,61 @@
package guard
import (
"time"
log "github.com/sirupsen/logrus"
)
const (
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
maxICERetries = 3
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
iceRetryInterval = 1 * time.Hour
)
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
log *log.Entry
retries int
hourly *time.Ticker
}
func (s *iceRetryState) reset() {
s.retries = 0
if s.hourly != nil {
s.hourly.Stop()
s.hourly = nil
}
}
// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
if s.hourly != nil {
s.log.Debugf("hourly ICE retry attempt")
return true
}
s.retries++
if s.retries <= maxICERetries {
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
return true
}
return false
}
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
s.hourly = time.NewTicker(iceRetryInterval)
}
func (s *iceRetryState) hourlyC() <-chan time.Time {
if s.hourly == nil {
return nil
}
return s.hourly.C
}

View File

@@ -0,0 +1,103 @@
package guard
import (
"testing"
log "github.com/sirupsen/logrus"
)
func newTestRetryState() *iceRetryState {
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
s := newTestRetryState()
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
}
}
}
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries; i++ {
_ = s.shouldRetry()
}
if s.shouldRetry() {
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
}
}
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
s := newTestRetryState()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
}
}
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
defer s.reset()
if s.hourlyC() == nil {
t.Fatalf("hourlyC returned nil after enterHourlyMode")
}
}
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
s := newTestRetryState()
s.enterHourlyMode()
defer s.reset()
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false in hourly mode, want true")
}
// Subsequent calls also return true — we keep retrying on each hourly tick.
if !s.shouldRetry() {
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
}
}
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
s.reset()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel after reset")
}
if s.retries != 0 {
t.Fatalf("retries = %d after reset, want 0", s.retries)
}
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
}
}
}
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
s := newTestRetryState()
s.reset()
s.reset() // second call must not panic or re-stop a nil ticker
if s.hourlyC() != nil {
t.Fatalf("hourlyC non-nil after double reset")
}
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw return srw
} }
func (w *SRWatcher) Start() { func (w *SRWatcher) Start(disableICEMonitor bool) {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
@@ -50,8 +50,10 @@ func (w *SRWatcher) Start() {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel w.cancelIceMonitor = cancel
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) if !disableICEMonitor {
go iceMonitor.Start(ctx, w.onICEChanged) iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
}
w.signalClient.SetOnReconnectedListener(w.onReconnected) w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -3,7 +3,9 @@ package peer
import ( import (
"context" "context"
"errors" "errors"
"net/netip"
"sync" "sync"
"sync/atomic"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -39,10 +41,18 @@ type OfferAnswer struct {
// relay server address // relay server address
RelaySrvAddress string RelaySrvAddress string
// RelaySrvIP is the IP the remote peer is connected to on its
// relay server. Used as a dial target if DNS for RelaySrvAddress
// fails. Zero value if the peer did not advertise an IP.
RelaySrvIP netip.Addr
// SessionID is the unique identifier of the session, used to discard old messages // SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID SessionID *ICESessionID
} }
func (o *OfferAnswer) hasICECredentials() bool {
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}
type Handshaker struct { type Handshaker struct {
mu sync.Mutex mu sync.Mutex
log *log.Entry log *log.Entry
@@ -59,6 +69,10 @@ type Handshaker struct {
relayListener *AsyncOfferListener relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer) iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -66,7 +80,7 @@ type Handshaker struct {
} }
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
return &Handshaker{ h := &Handshaker{
log: log, log: log,
config: config, config: config,
signaler: signaler, signaler: signaler,
@@ -76,6 +90,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer), remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer),
} }
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
} }
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -90,18 +111,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
for { for {
select { select {
case remoteOfferAnswer := <-h.remoteOffersCh: case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts // Record signaling received for reconnection attempts
if h.metricsStages != nil { if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived() h.metricsStages.RecordSignalingReceived()
} }
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil { if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
} }
if h.iceListener != nil { if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer) h.iceListener(&remoteOfferAnswer)
} }
@@ -110,18 +133,20 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue continue
} }
case remoteOfferAnswer := <-h.remoteAnswerCh: case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
// Record signaling received for reconnection attempts // Record signaling received for reconnection attempts
if h.metricsStages != nil { if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived() h.metricsStages.RecordSignalingReceived()
} }
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil { if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer) h.relayListener.Notify(&remoteOfferAnswer)
} }
if h.iceListener != nil { if h.iceListener != nil && h.RemoteICESupported() {
h.iceListener(&remoteOfferAnswer) h.iceListener(&remoteOfferAnswer)
} }
case <-ctx.Done(): case <-ctx.Done():
@@ -183,20 +208,39 @@ func (h *Handshaker) sendAnswer() error {
} }
func (h *Handshaker) buildOfferAnswer() OfferAnswer { func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{ answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort, WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(), Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr, RosenpassAddr: h.config.RosenpassConfig.Addr,
SessionID: &sid,
} }
if addr, err := h.relay.RelayInstanceAddress(); err == nil { if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
}
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr answer.RelaySrvAddress = addr
answer.RelaySrvIP = ip
} }
return answer return answer
} }
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.hasICECredentials()
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -8,6 +8,7 @@ import (
type mocListener struct { type mocListener struct {
lastState int lastState int
wg sync.WaitGroup wg sync.WaitGroup
peersWg sync.WaitGroup
peers int peers int
} }
@@ -33,6 +34,7 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
} }
func (l *mocListener) OnPeersListChanged(size int) { func (l *mocListener) OnPeersListChanged(size int) {
l.peers = size l.peers = size
l.peersWg.Done()
} }
func (l *mocListener) setWaiter() { func (l *mocListener) setWaiter() {
@@ -43,6 +45,14 @@ func (l *mocListener) wait() {
l.wg.Wait() l.wg.Wait()
} }
func (l *mocListener) setPeersWaiter() {
l.peersWg.Add(1)
}
func (l *mocListener) waitPeers() {
l.peersWg.Wait()
}
func Test_notifier_serverState(t *testing.T) { func Test_notifier_serverState(t *testing.T) {
type scenario struct { type scenario struct {
@@ -72,11 +82,13 @@ func Test_notifier_serverState(t *testing.T) {
func Test_notifier_SetListener(t *testing.T) { func Test_notifier_SetListener(t *testing.T) {
listener := &mocListener{} listener := &mocListener{}
listener.setWaiter() listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier() n := newNotifier()
n.lastNotification = stateConnecting n.lastNotification = stateConnecting
n.setListener(listener) n.setListener(listener)
listener.wait() listener.wait()
listener.waitPeers()
if listener.lastState != n.lastNotification { if listener.lastState != n.lastNotification {
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
} }
@@ -85,9 +97,14 @@ func Test_notifier_SetListener(t *testing.T) {
func Test_notifier_RemoveListener(t *testing.T) { func Test_notifier_RemoveListener(t *testing.T) {
listener := &mocListener{} listener := &mocListener{}
listener.setWaiter() listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier() n := newNotifier()
n.lastNotification = stateConnecting n.lastNotification = stateConnecting
n.setListener(listener) n.setListener(listener)
// setListener replays cached state on a goroutine; wait for both the state
// and peers callbacks to finish so we don't race on listener.peers.
listener.wait()
listener.waitPeers()
n.removeListener() n.removeListener()
n.peerListChanged(1) n.peerListChanged(1)

View File

@@ -46,23 +46,27 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer // SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
sessionIDBytes, err := offerAnswer.SessionID.Bytes() var sessionIDBytes []byte
if err != nil { if offerAnswer.SessionID != nil {
log.Warnf("failed to get session ID bytes: %v", err) var err error
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
} }
msg, err := signal.MarshalCredential( msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
s.wgPrivateKey, Type: bodyType,
offerAnswer.WgListenPort, WgListenPort: offerAnswer.WgListenPort,
remoteKey, Credential: &signal.Credential{
&signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag, UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd, Pwd: offerAnswer.IceCredentials.Pwd,
}, },
bodyType, RosenpassPubKey: offerAnswer.RosenpassPubKey,
offerAnswer.RosenpassPubKey, RosenpassAddr: offerAnswer.RosenpassAddr,
offerAnswer.RosenpassAddr, RelaySrvAddress: offerAnswer.RelaySrvAddress,
offerAnswer.RelaySrvAddress, RelaySrvIP: offerAnswer.RelaySrvIP,
sessionIDBytes) SessionID: sessionIDBytes,
})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
// UpdatePeerState updates peer status // UpdatePeerState updates peer status
func (d *Status) UpdatePeerState(receivedState State) error { func (d *Status) UpdatePeerState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -343,23 +343,29 @@ func (d *Status) UpdatePeerState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
d.notifyPeerListChanged()
}
// when we close the connection we will not notify the router manager // when we close the connection we will not notify the router manager
if receivedState.ConnStatus == StatusIdle { notifyRouter := receivedState.ConnStatus == StatusIdle
d.notifyPeerStateChangeListeners(receivedState.PubKey) routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
} }
return nil return nil
} }
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer] peerState, ok := d.peers[peer]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -371,17 +377,20 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
d.routeIDLookup.AddRemoteRouteID(resourceId, pref) d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
} }
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: consider to make sense of this notification or not // todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifier.peerListChanged(numPeers)
return nil return nil
} }
func (d *Status) RemovePeerStateRoute(peer string, route string) error { func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer] peerState, ok := d.peers[peer]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -393,8 +402,11 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.routeIDLookup.RemoveRemoteRouteID(pref) d.routeIDLookup.RemoveRemoteRouteID(pref)
} }
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: consider to make sense of this notification or not // todo: consider to make sense of this notification or not
d.notifyPeerListChanged() d.notifier.peerListChanged(numPeers)
return nil return nil
} }
@@ -410,10 +422,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
func (d *Status) UpdatePeerICEState(receivedState State) error { func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -431,22 +443,28 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
d.notifyPeerListChanged() notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
} routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { d.mux.Unlock()
d.notifyPeerStateChangeListeners(receivedState.PubKey)
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
} }
return nil return nil
} }
func (d *Status) UpdatePeerRelayedState(receivedState State) error { func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -461,22 +479,28 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
d.notifyPeerListChanged() notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
} routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { d.mux.Unlock()
d.notifyPeerStateChangeListeners(receivedState.PubKey)
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
} }
return nil return nil
} }
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error { func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -490,22 +514,28 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
d.notifyPeerListChanged() notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
} routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { d.mux.Unlock()
d.notifyPeerStateChangeListeners(receivedState.PubKey)
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
} }
return nil return nil
} }
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -522,12 +552,18 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
if hasConnStatusChanged(oldState, receivedState.ConnStatus) { notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
d.notifyPeerListChanged() notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
} routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) { d.mux.Unlock()
d.notifyPeerStateChangeListeners(receivedState.PubKey)
if notifyList {
d.notifier.peerListChanged(numPeers)
}
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
} }
return nil return nil
} }
@@ -594,17 +630,33 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
// FinishPeerListModifications this event invoke the notification // FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() { func (d *Status) FinishPeerListModifications() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification { if !d.peerListChangedForNotification {
d.mux.Unlock()
return return
} }
d.peerListChangedForNotification = false d.peerListChangedForNotification = false
d.notifyPeerListChanged() numPeers := d.numOfPeers()
// snapshot per-peer router state to deliver after the lock is released
type routerDispatch struct {
peerID string
snapshot map[string]RouterState
}
dispatches := make([]routerDispatch, 0, len(d.peers))
for key := range d.peers { for key := range d.peers {
d.notifyPeerStateChangeListeners(key) snapshot := d.snapshotRouterPeersLocked(key, true)
if snapshot != nil {
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
}
}
d.mux.Unlock()
d.notifier.peerListChanged(numPeers)
for _, rd := range dispatches {
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
} }
} }
@@ -655,10 +707,12 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
// UpdateLocalPeerState updates local peer status // UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
d.localPeer = localPeerState d.localPeer = localPeerState
d.notifyAddressChanged() fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
} }
// AddLocalPeerStateRoute adds a route to the local peer state // AddLocalPeerStateRoute adds a route to the local peer state
@@ -721,30 +775,36 @@ func (d *Status) CleanLocalPeerStateRoutes() {
// CleanLocalPeerState cleans local peer status // CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() { func (d *Status) CleanLocalPeerState() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
d.localPeer = LocalPeerState{} d.localPeer = LocalPeerState{}
d.notifyAddressChanged() fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip)
} }
// MarkManagementDisconnected sets ManagementState to disconnected // MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(err error) { func (d *Status) MarkManagementDisconnected(err error) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = false d.managementState = false
d.managementError = err d.managementError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
// MarkManagementConnected sets ManagementState to connected // MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected() { func (d *Status) MarkManagementConnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = true d.managementState = true
d.managementError = nil d.managementError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
// UpdateSignalAddress update the address of the signal server // UpdateSignalAddress update the address of the signal server
@@ -778,21 +838,25 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
// MarkSignalDisconnected sets SignalState to disconnected // MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) { func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = false d.signalState = false
d.signalError = err d.signalError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
// MarkSignalConnected sets SignalState to connected // MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected() { func (d *Status) MarkSignalConnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = true d.signalState = true
d.signalError = nil d.signalError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
@@ -919,7 +983,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// if the server connection is not established then we will use the general address // if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address // in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress() instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil { if err != nil {
// TODO add their status // TODO add their status
for _, r := range d.relayMgr.ServerURLs() { for _, r := range d.relayMgr.ServerURLs() {
@@ -1012,18 +1076,17 @@ func (d *Status) RemoveConnectionListener() {
d.notifier.removeListener() d.notifier.removeListener()
} }
func (d *Status) onConnectionChanged() { // snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
d.notifier.updateServerStates(d.managementState, d.signalState) // Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
} // or when notify is false. The snapshot is consumed later by dispatchRouterPeers
// outside the lock so the channel send cannot stall any d.mux holder.
// notifyPeerStateChangeListeners notifies route manager about the change in peer state func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
func (d *Status) notifyPeerStateChangeListeners(peerID string) { if !notify {
subs, ok := d.changeNotify[peerID] return nil
if !ok { }
return if _, ok := d.changeNotify[peerID]; !ok {
return nil
} }
// collect the relevant data for router peers
routerPeers := make(map[string]RouterState, len(d.changeNotify)) routerPeers := make(map[string]RouterState, len(d.changeNotify))
for pid := range d.changeNotify { for pid := range d.changeNotify {
s, ok := d.peers[pid] s, ok := d.peers[pid]
@@ -1031,13 +1094,35 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
log.Warnf("router peer not found in peers list: %s", pid) log.Warnf("router peer not found in peers list: %s", pid)
continue continue
} }
routerPeers[pid] = RouterState{ routerPeers[pid] = RouterState{
Status: s.ConnStatus, Status: s.ConnStatus,
Relayed: s.Relayed, Relayed: s.Relayed,
Latency: s.Latency, Latency: s.Latency,
} }
} }
return routerPeers
}
// dispatchRouterPeers delivers a previously snapshotted router-state map to
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
// fresh, short read of d.changeNotify under the lock to grab subscriber
// channels, then sends outside the lock so a slow consumer cannot block other
// d.mux holders. The send itself stays blocking (only short-circuited by the
// subscriber's context) so peer state transitions are not silently dropped.
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
if routerPeers == nil {
return
}
d.mux.Lock()
subsMap, ok := d.changeNotify[peerID]
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
if ok {
for _, sub := range subsMap {
subs = append(subs, sub)
}
}
d.mux.Unlock()
for _, sub := range subs { for _, sub := range subs {
select { select {
@@ -1047,14 +1132,6 @@ func (d *Status) notifyPeerStateChangeListeners(peerID string) {
} }
} }
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}
func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
}
func (d *Status) numOfPeers() int { func (d *Status) numOfPeers() int {
return len(d.peers) + len(d.offlinePeers) return len(d.peers) + len(d.offlinePeers)
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -53,15 +54,19 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relaySupportedOnRemotePeer.Store(true) w.relaySupportedOnRemotePeer.Store(true)
// the relayManager will return with error in case if the connection has lost with relay server // the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, err := w.relayManager.RelayInstanceAddress() currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
if err != nil { if err != nil {
w.log.Errorf("failed to handle new offer: %s", err) w.log.Errorf("failed to handle new offer: %s", err)
return return
} }
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
var serverIP netip.Addr
if srv == remoteOfferAnswer.RelaySrvAddress {
serverIP = remoteOfferAnswer.RelaySrvIP
}
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")
@@ -90,7 +95,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}) })
} }
func (w *WorkerRelay) RelayInstanceAddress() (string, error) { func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
return w.relayManager.RelayInstanceAddress() return w.relayManager.RelayInstanceAddress()
} }

View File

@@ -0,0 +1,10 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -0,0 +1,253 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
reused := false
if err := r.addScopedDefault(unspec, nexthop); err != nil {
if !errors.Is(err, unix.EEXIST) {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
// macOS installs its own RTF_IFSCOPE defaults for primary service
// selection on multi-NIC setups, so a route on this ifindex can
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
// still produces the scoped lookup we need.
reused = true
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
verb := "installed"
if reused {
verb = "reused existing"
}
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars" "github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks" "github.com/netbirdio/netbird/client/net/hooks"
) )
@@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{}) stateManager.RegisterState(&ShutdownState{})
@@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
} }
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
localRoutes, err := hasSeparateRouting() if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
if err != nil { if err != nil {
if !errors.Is(err, ErrRoutingIsSeparate) { log.Errorf("Failed to get routes: %v", err)
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{} return false, netip.Prefix{}
} }

View File

@@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil return []netip.Prefix{}, nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM. // GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil return []DetailedRoute{}, nil

View File

@@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6 return netlink.FAMILY_V6
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool { func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,10 +48,6 @@ func EnableIPForwarding() error {
return nil return nil
} }
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) // GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) { func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS) log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,6 +25,9 @@ import (
const ( const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
) )
var routeProtoFlag int var routeProtoFlag int
@@ -41,26 +44,42 @@ func init() {
} }
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager) return r.setupRefCounter(initAddresses, stateManager)
} }
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager) return r.cleanupRefCounter(stateManager)
} }
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. // FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error { func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB() rib, err := retryFetchRIB()
if err != nil { if err != nil {
return fmt.Errorf("fetch routing table: %w", err) return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
} }
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil { if err != nil {
return fmt.Errorf("parse routing table: %w", err) return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
} }
var merr *multierror.Error
flushedCount := 0 flushedCount := 0
for _, msg := range msgs { for _, msg := range msgs {
@@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix) return fmt.Errorf("invalid prefix: %s", prefix)
} }
expBackOff := backoff.NewExponentialBackOff() msg, err := r.buildRouteMessage(action, prefix, nexthop)
expBackOff.InitialInterval = 50 * time.Millisecond if err != nil {
expBackOff.MaxInterval = 500 * time.Millisecond return fmt.Errorf("build route message: %w", err)
expBackOff.MaxElapsedTime = 1 * time.Second }
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil { if err := r.writeRouteMessage(msg, routeBudget); err != nil {
a := "add" a := "add"
if action == unix.RTM_DELETE { if action == unix.RTM_DELETE {
a = "remove" a = "remove"
@@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil return nil
} }
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error { // writeRouteMessage sends a route message over AF_ROUTE and waits for the
operation := func() error { // kernel's matching reply, retrying transient failures until budget elapses.
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) // Callers do not need to manage sockets or seq numbers themselves.
if err != nil { func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
return fmt.Errorf("open routing socket: %w", err) expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
} }
defer func() { }()
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err) tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
} }
}() return backoff.Permanent(fmt.Errorf("read: %w", err))
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
} }
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
msgBytes, err := msg.Marshal() continue
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
} }
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
if _, err = unix.Write(fd, msgBytes); err != nil { // Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) { // uniquely identifies our own reply among broadcast traffic.
return fmt.Errorf("write: %w", err) if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
} continue
return backoff.Permanent(fmt.Errorf("write: %w", err))
} }
if hdr.Errno != 0 {
respBuf := make([]byte, 2048) return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
} }
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil return nil
} }
return operation
} }
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action, Type: action,
Flags: unix.RTF_UP | routeProtoFlag, Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION, Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(), Seq: r.getSeq(),
} }
@@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil return msg, nil
} }
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). // addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() { if addr.Is4() {

View File

@@ -7,7 +7,6 @@ import (
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -44,8 +43,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
maps.Clear(rs.deselectedRoutes) clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes) clear(rs.selectedRoutes)
for _, r := range allRoutes { for _, r := range allRoutes {
rs.deselectedRoutes[r] = struct{}{} rs.deselectedRoutes[r] = struct{}{}
} }
@@ -78,8 +77,8 @@ func (rs *RouteSelector) SelectAllRoutes() {
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
maps.Clear(rs.deselectedRoutes) clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes) clear(rs.selectedRoutes)
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
@@ -116,8 +115,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
maps.Clear(rs.deselectedRoutes) clear(rs.deselectedRoutes)
maps.Clear(rs.selectedRoutes) clear(rs.selectedRoutes)
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.

View File

@@ -2,217 +2,358 @@
package sleep package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
var ( // IOKit message types from IOKit/IOMessage.h.
serviceRegistry = make(map[*Detector]struct{}) const (
serviceRegistryMu sync.Mutex kIOMessageCanSystemSleep uintptr = 0xe0000270
kIOMessageSystemWillSleep uintptr = 0xe0000280
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
) )
//export sleepCallbackBridge var (
func sleepCallbackBridge() { ioKit iokitFuncs
log.Info("sleepCallbackBridge event triggered") cf cfFuncs
cfCommonModes uintptr
serviceRegistryMu.Lock() libInitOnce sync.Once
defer serviceRegistryMu.Unlock() libInitErr error
for svc := range serviceRegistry { // callbackThunk is the single C-callable trampoline registered with IOKit.
svc.triggerCallback(EventTypeSleep) callbackThunk uintptr
}
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
session *runLoopSession
// lifecycleMu serializes Register/Deregister so a new registration can't
// start a second runloop while a previous teardown is still pending.
lifecycleMu sync.Mutex
)
// iokitFuncs holds IOKit symbols resolved once at init.
type iokitFuncs struct {
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
IODeregisterForSystemPower func(notifier *uintptr) int32
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
IOServiceClose func(connect uintptr) int32
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
IONotificationPortDestroy func(port uintptr)
} }
//export resumedCallbackBridge // cfFuncs holds CoreFoundation symbols resolved once at init.
func resumedCallbackBridge() { type cfFuncs struct {
log.Info("resumedCallbackBridge event triggered") CFRunLoopGetCurrent func() uintptr
CFRunLoopRun func()
CFRunLoopStop func(rl uintptr)
CFRunLoopAddSource func(rl, source, mode uintptr)
CFRunLoopRemoveSource func(rl, source, mode uintptr)
} }
//export suspendedCallbackBridge // runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
func suspendedCallbackBridge() { // session means no runloop is active and the next Register must start one.
log.Info("suspendedCallbackBridge event triggered") type runLoopSession struct {
rl uintptr
port uintptr
notifier uintptr
rp uintptr
} }
//export poweredOnCallbackBridge // detectorSnapshot pins a detector's callback and done channel so dispatch
func poweredOnCallbackBridge() { // runs with values valid at snapshot time, even if a concurrent
log.Info("poweredOnCallbackBridge event triggered") // Deregister/Register rewrites the detector's fields.
serviceRegistryMu.Lock() type detectorSnapshot struct {
defer serviceRegistryMu.Unlock() detector *Detector
callback func(event EventType)
for svc := range serviceRegistry { done <-chan struct{}
svc.triggerCallback(EventTypeWakeUp)
}
} }
// Detector delivers sleep and wake events to a registered callback.
type Detector struct { type Detector struct {
callback func(event EventType) callback func(event EventType)
ctx context.Context done chan struct{}
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
} }
// Register installs callback for power events. The first registration starts
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
// registration succeeds or fails; subsequent registrations just add to the
// dispatch set.
func (d *Detector) Register(callback func(event EventType)) error { func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock() lifecycleMu.Lock()
defer serviceRegistryMu.Unlock() defer lifecycleMu.Unlock()
serviceRegistryMu.Lock()
if _, exists := serviceRegistry[d]; exists { if _, exists := serviceRegistry[d]; exists {
serviceRegistryMu.Unlock()
return fmt.Errorf("detector service already registered") return fmt.Errorf("detector service already registered")
} }
d.callback = callback d.callback = callback
d.done = make(chan struct{})
serviceRegistry[d] = struct{}{}
needSetup := session == nil
serviceRegistryMu.Unlock()
d.ctx, d.cancel = context.WithCancel(context.Background()) if !needSetup {
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
return nil return nil
} }
serviceRegistry[d] = struct{}{} errCh := make(chan error, 1)
go runRunLoop(errCh)
// CFRunLoop must run on a single fixed OS thread if err := <-errCh; err != nil {
go func() { serviceRegistryMu.Lock()
runtime.LockOSThread() delete(serviceRegistry, d)
defer runtime.UnlockOSThread() close(d.done)
d.done = nil
C.registerNotifications() serviceRegistryMu.Unlock()
}() return err
}
log.Info("sleep detection service started on macOS") log.Info("sleep detection service started on macOS")
return nil return nil
} }
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down // Deregister removes the detector. When the last detector leaves, IOKit
// and the runloop is stopped and cleaned up. // notifications are torn down and the runloop is stopped.
func (d *Detector) Deregister() error { func (d *Detector) Deregister() error {
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
serviceRegistryMu.Lock() serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock() if _, exists := serviceRegistry[d]; !exists {
_, exists := serviceRegistry[d] serviceRegistryMu.Unlock()
if !exists {
return nil return nil
} }
close(d.done)
// cancel and remove this detector
d.cancel()
delete(serviceRegistry, d) delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 { if len(serviceRegistry) > 0 {
serviceRegistryMu.Unlock()
return nil return nil
} }
sess := session
serviceRegistryMu.Unlock()
log.Info("sleep detection service stopping (deregister)") log.Info("sleep detection service stopping (deregister)")
// Deregister IOKit notifications, stop runloop, and free resources if sess == nil {
C.unregisterNotifications() return nil
}
if sess.rl != 0 && sess.port != 0 {
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
}
if sess.notifier != 0 {
n := sess.notifier
ioKit.IODeregisterForSystemPower(&n)
}
// Clear session only after IODeregisterForSystemPower returns so any
// in-flight powerCallback can still look up session.rp to ack sleep.
serviceRegistryMu.Lock()
session = nil
serviceRegistryMu.Unlock()
if sess.rp != 0 {
ioKit.IOServiceClose(sess.rp)
}
if sess.port != 0 {
ioKit.IONotificationPortDestroy(sess.port)
}
if sess.rl != 0 {
cf.CFRunLoopStop(sess.rl)
}
return nil return nil
} }
func (d *Detector) triggerCallback(event EventType) { func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
doneChan := make(chan struct{}) if cb == nil || done == nil {
return
}
select {
case <-done:
return
default:
}
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond) timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop() defer timeout.Stop()
cb := d.callback go func() {
go func(callback func(event EventType)) { defer close(doneChan)
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep callback: %v", r)
}
}()
log.Info("sleep detection event fired") log.Info("sleep detection event fired")
callback(event) cb(event)
close(doneChan) }()
}(cb)
select { select {
case <-doneChan: case <-doneChan:
case <-d.ctx.Done(): case <-done:
case <-timeout.C: case <-timeout.C:
log.Warnf("sleep callback timed out") log.Warn("sleep callback timed out")
} }
} }
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
func NewDetector() (*Detector, error) {
if err := initLibs(); err != nil {
return nil, err
}
return &Detector{}, nil
}
func initLibs() error {
libInitOnce.Do(func() {
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
return
}
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
return
}
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
if err != nil {
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
return
}
// Launder the uintptr-to-pointer conversion through a Go variable so
// go vet's unsafeptr analyzer doesn't flag a system-library global.
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
// NewCallback slots are a finite, non-reclaimable resource, so register
// a single thunk that dispatches to the current Detector set.
callbackThunk = purego.NewCallback(powerCallback)
})
return libInitErr
}
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
// runloop thread. A Go panic crossing the purego boundary has undefined
// behavior, so contain it here.
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep powerCallback: %v", r)
}
}()
switch messageType {
case kIOMessageCanSystemSleep:
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
allowPowerChange(messageArgument)
case kIOMessageSystemWillSleep:
dispatchEvent(EventTypeSleep)
allowPowerChange(messageArgument)
case kIOMessageSystemHasPoweredOn:
dispatchEvent(EventTypeWakeUp)
}
return 0
}
func allowPowerChange(messageArgument uintptr) {
serviceRegistryMu.Lock()
var port uintptr
if session != nil {
port = session.rp
}
serviceRegistryMu.Unlock()
if port != 0 {
ioKit.IOAllowPowerChange(port, messageArgument)
}
}
func dispatchEvent(event EventType) {
serviceRegistryMu.Lock()
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
for d := range serviceRegistry {
snaps = append(snaps, detectorSnapshot{
detector: d,
callback: d.callback,
done: d.done,
})
}
serviceRegistryMu.Unlock()
for _, s := range snaps {
s.detector.triggerCallback(event, s.callback, s.done)
}
}
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
// result is reported on errCh so Register can surface failures synchronously.
func runRunLoop(errCh chan<- error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
sess, err := setupSession()
if err == nil {
serviceRegistryMu.Lock()
session = sess
serviceRegistryMu.Unlock()
}
errCh <- err
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep runloop: %v", r)
}
}()
cf.CFRunLoopRun()
}
// setupSession performs the IOKit registration on the current thread. Panics
// are converted to errors so runRunLoop never leaves errCh unsent.
func setupSession() (s *runLoopSession, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during runloop setup: %v", r)
}
}()
var portRef, notifier uintptr
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, &notifier)
if rp == 0 {
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
}
rl := cf.CFRunLoopGetCurrent()
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
}

View File

@@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
"time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -28,6 +27,10 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
// Start begins monitoring the WireGuard interface. // Start begins monitoring the WireGuard interface.
// It relies on the provided context cancellation to stop. // It relies on the provided context cancellation to stop.
//
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
// to avoid the allocation churn of repeatedly dumping the kernel link
// table; on other platforms it falls back to a low-frequency poll.
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) { func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
defer close(m.done) defer close(m.done)
@@ -56,31 +59,7 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex) log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
ticker := time.NewTicker(2 * time.Second) return watchInterface(ctx, ifaceName, expectedIndex)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
} }
// getInterfaceIndex returns the index of a network interface by name. // getInterfaceIndex returns the index of a network interface by name.

View File

@@ -0,0 +1,134 @@
//go:build linux
package internal
import (
"context"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
// deletion or recreation of the WireGuard interface.
//
// The previous implementation polled net.InterfaceByName every 2 s, which
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
// entire kernel link table on every call. On hosts with many veth
// interfaces (containers, bridges) the resulting allocation churn was on
// the order of ~1 GB/day from this single ticker, which on small ARM
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
//
// The event-driven version below allocates only when the kernel actually
// publishes a link event for the tracked interface — typically zero
// allocations between events.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
done := make(chan struct{})
defer close(done)
// Buffer the channel to absorb event bursts (e.g. when many veth
// pairs are created/destroyed at once by container runtimes).
linkChan := make(chan netlink.LinkUpdate, 32)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
// Return shouldRestart=true so the engine recovers monitoring
// via triggerClientRestart instead of silently losing it for
// the rest of the process lifetime.
return true, fmt.Errorf("subscribe to link updates: %w", err)
}
// Race window: the interface could have been deleted (or recreated)
// between the initial getInterfaceIndex() in Start and LinkSubscribe
// completing its handshake with the kernel. Re-check explicitly so we
// do not block forever waiting for an event that already fired.
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
} else if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case update, ok := <-linkChan:
if !ok {
// The vishvananda/netlink subscription goroutine closes
// the channel on receive errors. Signal the engine to
// restart so monitoring is re-established instead of
// silently ending.
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
return true, fmt.Errorf("link subscription channel closed unexpectedly")
}
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
return true, err
}
}
}
}
// inspectLinkEvent classifies a single netlink link update against the
// tracked WireGuard interface. It returns (true, err) when the engine
// should restart monitoring; (false, nil) means the event is unrelated
// and the caller should keep waiting.
//
// The error component, when non-nil, describes the kernel-side reason
// (deletion or rename); the recreation case returns (true, nil) since
// no error condition is reported.
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
eventIndex := int(update.Index)
eventName := ""
if attrs := update.Attrs(); attrs != nil {
eventName = attrs.Name
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
case syscall.RTM_NEWLINK:
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
}
return false, nil
}
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
// tracked interface index.
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
if eventIndex != expectedIndex {
return false, nil
}
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted", ifaceName)
}
// inspectNewLink reports a restart when an RTM_NEWLINK either:
//
// 1. Introduces a link with our name at a different index (recreation
// after a delete), or
//
// 2. Reports a link still at our index but with a different name
// (in-place rename). The previous polling implementation caught
// this implicitly because net.InterfaceByName(ifaceName) would
// start failing; the event-driven version has to test it.
//
// Same name + same index is just a flag/state change on the existing
// interface and is ignored.
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
if eventName == ifaceName && eventIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, eventIndex)
return true, nil
}
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
ifaceName, eventName, expectedIndex)
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
}
return false, nil
}

View File

@@ -0,0 +1,56 @@
//go:build !linux
package internal
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// watchInterface polls net.InterfaceByName at a fixed interval to detect
// deletion or recreation of the WireGuard interface.
//
// This is the fallback used on non-Linux desktop and server platforms
// (darwin, windows, freebsd). It is also compiled on android and ios so
// the package builds on every supported GOOS, but it is never reached
// at runtime there because Start() in wg_iface_monitor.go exits early
// on mobile platforms.
//
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
// RTNLGRP_LINK netlink subscription instead, because on Linux
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
// dumps the entire kernel link table on every call and produces
// significant allocation churn (netbirdio/netbird#3678).
//
// Windows is also reported in #3678 as affected by RSS climb. A future
// follow-up could implement an event-driven watcher there using
// NotifyIpInterfaceChange from iphlpapi.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}

View File

@@ -0,0 +1,5 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows //go:build !linux && !windows && !darwin
package net package net

View File

@@ -1,24 +0,0 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -1,4 +1,4 @@
//go:build windows //go:build (darwin && !ios) || windows
package net package net
@@ -24,17 +24,22 @@ func Init() {
} }
func checkAdvancedRoutingSupport() bool { func checkAdvancedRoutingSupport() bool {
var err error legacyRouting := false
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" { if val := os.Getenv(envUseLegacyRouting); val != "" {
legacyRouting, err = strconv.ParseBool(val) parsed, err := strconv.ParseBool(val)
if err != nil { if err != nil {
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
} }
} }
if legacyRouting || netstack.IsEnabled() { if legacyRouting {
log.Info("advanced routing has been requested to be disabled") log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
return false return false
} }

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !android //go:build !linux && !windows && !darwin
package net package net

25
client/net/env_mobile.go Normal file
View File

@@ -0,0 +1,25 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,5 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows //go:build !linux && !windows && !darwin
package net package net

160
client/net/net_darwin.go Normal file
View File

@@ -0,0 +1,160 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

Some files were not shown because too many files have changed in this diff Show More