Compare commits

..

22 Commits

Author SHA1 Message Date
dependabot[bot]
9acc04817f Bump github.com/gopacket/gopacket from 1.4.0 to 1.6.1
Bumps [github.com/gopacket/gopacket](https://github.com/gopacket/gopacket) from 1.4.0 to 1.6.1.
- [Release notes](https://github.com/gopacket/gopacket/releases)
- [Commits](https://github.com/gopacket/gopacket/compare/v1.4.0...v1.6.1)

---
updated-dependencies:
- dependency-name: github.com/gopacket/gopacket
  dependency-version: 1.6.1
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2026-06-30 02:49:14 +00:00
Zoltan Papp
04c3d19032 [client] Skip firewall ruleset rebuild when config is unchanged (#6508)
* [client] Skip firewall ruleset rebuild when config is unchanged

ApplyFiltering rebuilt every peer and route ACL and flushed the firewall
on every sync, with no guard for an unchanged configuration. Management
re-sends the same network map far more often than it actually changes
(account-wide updates, peer meta churn), so on busy accounts this is the
dominant client-side cost of redundant syncs — especially with a large
route set and a userspace firewall.

Hash the inputs ApplyFiltering consumes (peer rules, route rules, the
empty flag and the dns-route feature flag) and skip the rebuild + flush
when the hash matches the last successfully applied update. Mirrors the
guard the DNS server already uses (previousConfigHash). The hash is only
recorded after apply and flush both succeed, so a failed update is not
skipped on the next (possibly identical) sync and gets a chance to
reconcile the firewall state.

* [client] Include config hash in ACL skip debug log

* [client] Include RoutesFirewallRulesIsEmpty in firewall config hash

* [client] Add benchmarks for firewall config hash computation
2026-06-29 19:51:50 +02:00
Zoltan Papp
3f1fb3b52d [ingest] raise duration validation limit to 24 hours (#6598)
Peer connection timing fields (signaling_to_connection_seconds) can
legitimately exceed 5 minutes during long reconnections; the previous
300 s cap caused valid data points to be rejected.
2026-06-29 19:51:25 +02:00
Viktor Liu
b434cda062 [client] Refresh signal receive liveness when worker handoff drains (#6594) 2026-06-29 12:16:47 +02:00
Zoltan Papp
0b594c639a [client] report management unhealthy while Sync stream is failing (#6575)
* fix(mgm): report management unhealthy while Sync stream is failing

The health probe (IsHealthy) only checked the gRPC transport and a
GetServerKey call. GetServerKey succeeds even when the peer cannot sync
(e.g. the server returns "settings not found"), so the probe kept marking
management Connected while the Sync stream failed in a tight retry loop —
pinning the status to "Connected" forever despite no sync ever succeeding.

Track the last Sync stream error and have IsHealthy consult it, so a
healthy transport is no longer enough to report the connection healthy.

* fix(mgm): record disconnected state when sync stream setup fails

The connectToSyncStream failure path in handleSyncStream returned early
without updating syncStreamErr, so the client could still report healthy
even when stream setup failed. Mirror the receiveUpdatesEvents error path
by calling notifyDisconnected and setSyncStreamDisconnected.
2026-06-29 11:28:58 +02:00
Zoltan Papp
deff8af59f [client] Wait for signal receive watchdog to stop before reconnect (#6574)
* [client] Wait for signal receive watchdog to stop before reconnect

The per-stream watchReceiveStream goroutine was started fire-and-forget
and never joined. On reconnect a lingering watchdog could still flip
shared client state (receiveStalled, the disconnect notifier) on the
freshly established stream, since cancelStream only cancels its own
stream context.

Track the watchdog with a WaitGroup and wait for it to exit (after
cancelling its stream) before the operation returns, so each reconnect
starts with no stale watchdog.

* [client] Bind signal receive probe to the stream context

The watchdog probe reused the generic Send, which derives its per-attempt
timeouts from the long-lived client context, so cancelStream could not
interrupt an in-flight probe. After joining the watchdog on reconnect,
watchdogWg.Wait() could then block for the full send-attempt chain.

Split Send into a context-aware send and pass the stream context down
through sendReceiveProbe, so cancelStream aborts any in-flight probe and
the watchdog exits promptly.
2026-06-29 11:24:25 +02:00
Riccardo Manfrin
5711f0e38c [client] add per-phase timing metrics for sync processing (#6533)
* Adds metrics sync phases time split to monitor costs

* Address review fixes

* Increment README.md with description on usage with debug bundles
2026-06-29 11:02:02 +02:00
Maycon Santos
1409a1325a [misc] Update careers page link (#6538) 2026-06-29 09:19:01 +02:00
Viktor Liu
4400372f37 [client] Forward non-address DNS record types through route forwarders (#6455) 2026-06-28 18:50:17 +02:00
Zoltan Papp
2d7b309004 [client] Categorize privileged tests behind a build tag and run them in Docker (#6425)
* [client] categorize root/system-mutating tests behind a privileged build tag

Tests that need root or mutate host state (nftables/iptables/DNS, TUN/WireGuard
interfaces, routes, eBPF, SSH/service install) are now gated behind a
//go:build privileged tag. The default `go test ./client/...` runs as a non-root
user with no sudo and leaves host networking untouched; mixed files were split so
pure-logic tests stay in the default suite.

A self-hosting ory/dockertest/v4 harness (client/testutil/privileged) runs the
privileged suite inside a --privileged --cap-add=NET_ADMIN container via
`make test-privileged`; a DOCKER_CI=true guard skips the spawn when already inside
the container. Added `make test-unit` for the host-safe run.

* [client] add PRIV_RUN/PRIV_PKGS filters to the privileged test harness

The dockertest harness now reads two optional env vars when building the
in-container `go test` command: PRIV_RUN adds a -run test-name filter and
PRIV_PKGS overrides the package list. Both empty reproduce the full privileged
suite, so CI and `make test-privileged` behave as before. Lets a developer run a
single privileged test in the container, e.g.:

  PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged

* [client] fix unused-helper lint after the privileged test split

Splitting privileged tests into *_privileged_test.go left their shared helpers in
the untagged files, so in the default (no-tag) build they had no callers and
golangci-lint flagged them as unused.

Moved the privileged-only helpers into the privileged files next to their callers
(generateDummyHandler; createEngine/startSignal/startManagement/getConnectedPeers/
getPeers + kaep/kasp; (*mockDaemon).setJWTToken). Annotated the shared routing-test
fixtures that must stay untagged for cross-platform compilation with //nolint:unused
(systemops_bsd expected* vars, ensureIPv6DefaultRoute on bsd/windows,
loopbackIfaceWindows), matching the existing linux variant.

* [client] fix privileged test CI failures and run the harness on macOS

The host-safe unit run dropped sudo but two privileged test groups were
never tagged, and the Docker privileged job silently never ran the suite:

- Gate the ssh/server PrivilegeDropper command-construction tests behind
  the privileged tag (they require root to target a different UID); split
  them into executor_unix_privileged_test.go.
- Tag sharedsock raw-socket tests privileged (need CAP_NET_RAW).
- Fix the Docker job command: nested single quotes around the build tags
  closed the sh -c wrapper early, dropping the go list package set and the
  privileged tag, so go test ran on the empty repo root. Use double quotes.

Make the self-hosting harness usable from a dev Mac:

- Build it on darwin as well as linux; it only drives Docker.
- Resolve the active docker context endpoint into DOCKER_HOST when the
  default /var/run/docker.sock is absent (Docker Desktop, Colima, OrbStack).
- Rename the misspelled containerGoModache constant to containerGoModCache.

* Update client/internal/engine_privileged_test.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update client/internal/routemanager/systemops/systemops_linux_test.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update client/internal/routemanager/systemops/systemops_windows_test.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* Update client/server/server_privileged_test.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* [ci] Run privileged-tagged tests on darwin, windows and freebsd

The privileged build tag split moved root/system-mutating tests behind
//go:build privileged, but only the linux docker job was given the tag.
The native darwin (sudo), windows (PsExec64 -s) and freebsd VM runners
already have the required privileges, so add the privileged tag there too
to keep CI running the same set of tests as before the split.

* [ci] Exclude dockertest harness from the darwin privileged run

The privileged tag now compiles client/testutil/privileged on darwin, whose
TestRunPrivilegedSuiteInDocker spawns a container the macOS runner has no
Docker for. Exclude the harness package from the darwin list, matching the
linux job, so the privileged tests run in place without a container spawn.

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
2026-06-28 16:15:54 +02:00
Viktor Liu
5968cff242 [client] Keep signal stream alive while receive loop is blocked on worker handoff (#6530) 2026-06-28 15:33:30 +02:00
dependabot[bot]
cf43841b86 Bump the actions group across 1 directory with 4 updates (#6550)
Bumps the actions group with 4 updates in the / directory: [actions/setup-go](https://github.com/actions/setup-go), [actions/cache](https://github.com/actions/cache), [actions/cache/restore](https://github.com/actions/cache) and [actions/setup-java](https://github.com/actions/setup-java).


Updates `actions/setup-go` from 6.4.0 to 6.5.0
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](4a3601121d...924ae3a1cd)

Updates `actions/cache` from 5.0.5 to 6.0.0
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](27d5ce7f10...2c8a9bd745)

Updates `actions/cache/restore` from 5.0.5 to 6.0.0
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](27d5ce7f10...2c8a9bd745)

Updates `actions/setup-java` from 5.3.0 to 5.4.0
- [Release notes](https://github.com/actions/setup-java/releases)
- [Commits](ad2b38190b...1bcf9fb12c)

---
updated-dependencies:
- dependency-name: actions/setup-go
  dependency-version: 6.5.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: actions/cache
  dependency-version: 6.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
- dependency-name: actions/cache/restore
  dependency-version: 6.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
- dependency-name: actions/setup-java
  dependency-version: 5.4.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-28 15:00:05 +02:00
Maycon Santos
739e36a313 [self-hosted] Add agent-network preset with dedicated configurations (#6569) 2026-06-28 14:56:42 +02:00
Riccardo Manfrin
2bb5421631 These logs are needed for troubleshooting (debug) (#6565) 2026-06-28 14:52:41 +02:00
MAAZIZ Adel Ayoub
998ade6e6d [client] fix nil pointer panic when applying SSH server setting to an existing config (#6556) 2026-06-28 14:51:21 +02:00
Zoltan Papp
62f5467cd8 [client] Eliminate packet loss during lazy connections. (#6355)
* [client] Remove peer deletion on lazy activity detection

Updated WireGuard dependency with a patch and removed the RemovePeer
call on lazy activity detection to force a new handshake initiation
to the updated endpoint. This also flushed the staged queue, dropping
the first packet.

Since UpdatePeer (called after ICE/relay negotiation) triggers
SendStagedPackets via IpcSet/handlePostConfig, the peer removal is
no longer necessary. The staged packet survives and the handshake
is initiated on the real endpoint automatically.

This also eliminates the transient state where the peer's endpoint
and routes were absent between the lazy idle and connected states.

* Update WireGuard dependency

* Update WireGuard dependencies

* Update WireGuard dependency
2026-06-28 14:22:19 +02:00
Zoltan Papp
1b29995ece [client] Fix blocked status lock via relay manager path (#6547)
* peer/status: move relay-state reads off the main mux

GetRelayStates held d.mux (RLock) while calling into the relay
Manager (RelayStates/RelayConnectError/ServerURLs). Those calls can be
slow or block on the relay manager's own locks while it is reconnecting,
which kept the central Status mutex held and stalled every peer state
writer (UpdatePeerState, ReplaceOfflinePeers, etc.) contending for it.

Guard relayMgr/relayStates with a dedicated muxRelays mutex and release
it before invoking the relay Manager, so the relay read path no longer
contends with the hot peer-state writers on d.mux.

* peer/status: clone relay states in nil-manager path

Return a cloned snapshot of d.relayStates when relayMgr is nil so callers
cannot mutate the shared cached state, matching the non-nil path.
2026-06-28 12:45:33 +02:00
Zoltan Papp
fd96b8c12f [client] Improve network addresses filter (#6515)
* [client] Filter link-local and multicast from network addresses

Skip IPv6 link-local and multicast addresses when building the peer
network_addresses list on non-iOS platforms, matching the existing iOS
behavior. A flapping NIC's link-local address otherwise churns the peer
meta on every interface up/down.

* [client] Skip engine restart when default route is unchanged

After the network monitor's debounce window, re-check the default next
hop before triggering a client restart. A flapping NIC that returns to
the same default route no longer forces a restart, avoiding redundant
sync stream reconnects and peer meta churn.

* [client] Exclude own overlay address from reported network addresses

The peer's own WireGuard overlay address (v4 and v6) was reported in
network_addresses. As the interface comes and goes during reconnects it
churned the peer meta on the management server. Drop it in
GetInfoWithChecks, matching the IP regardless of prefix length since the
engine knows the overlay address with the network mask while the
interface reports it as a host address.

* [client] Treat missing default route per protocol in next-hop check

A failed GetNextHop lookup is now treated as an absent route (zero
Nexthop) and compared per protocol, instead of forcing a restart. In a
single-stack network the missing IPv6 default route no longer counts as
a change on every debounce, which previously defeated the unchanged-route
check.

* [client] Make next-hop check injectable for network monitor tests

Move the next-hop comparison behind a NetworkMonitor field set by New(),
so tests can supply a stub instead of hitting the host's real default
route. Fixes the Event/MultiEvent tests hanging after the unchanged-route
check was added.

* Revert "[client] Make next-hop check injectable for network monitor tests"

This reverts commit 88a9d96e8f.

* Revert "[client] Treat missing default route per protocol in next-hop check"

This reverts commit 0fb531e4bc.

* Revert "[client] Skip engine restart when default route is unchanged"

This reverts commit a071b55f35.
2026-06-28 12:44:40 +02:00
Misha Bragin
6dd6c3f398 [Doc] Point Agent Network banner to netbird.ai (#6564) 2026-06-28 12:20:55 +02:00
Misha Bragin
d1422dcf09 [misc] Add agent-network readme (#6562) 2026-06-27 23:00:41 +02:00
dmitri-netbird
615631567a small gh workflow fixes (#6546)
Signed-off-by: Dmitri Dolguikh <dmitri.external@netbird.io>
2026-06-26 19:59:15 +02:00
Pascal Fischer
f4daf59bcd [management] bring back client version check on login filter hash (#6552) 2026-06-26 16:36:50 +02:00
154 changed files with 5727 additions and 12125 deletions

View File

@@ -64,7 +64,7 @@ jobs:
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: true

View File

@@ -21,13 +21,13 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: ~/go/pkg/mod
key: macos-gotest-${{ hashFiles('**/go.sum') }}
@@ -45,7 +45,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0

View File

@@ -48,14 +48,14 @@ jobs:
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
time go build -o netbird client/main.go
# check all component except management, since we do not support management server on freebsd
time go test -timeout 1m -failfast ./base62/...
time go test -tags privileged -timeout 1m -failfast ./base62/...
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
time go test -timeout 8m -failfast -v -p 1 ./client/...
time go test -timeout 1m -failfast ./dns/...
time go test -timeout 1m -failfast ./encryption/...
time go test -timeout 1m -failfast ./formatter/...
time go test -timeout 1m -failfast ./client/iface/...
time go test -timeout 1m -failfast ./route/...
time go test -timeout 1m -failfast ./sharedsock/...
time go test -timeout 1m -failfast ./util/...
time go test -timeout 1m -failfast ./version/...
time go test -tags privileged -timeout 8m -failfast -v -p 1 ./client/...
time go test -tags privileged -timeout 1m -failfast ./dns/...
time go test -tags privileged -timeout 1m -failfast ./encryption/...
time go test -tags privileged -timeout 1m -failfast ./formatter/...
time go test -tags privileged -timeout 1m -failfast ./client/iface/...
time go test -tags privileged -timeout 1m -failfast ./route/...
time go test -tags privileged -timeout 1m -failfast ./sharedsock/...
time go test -tags privileged -timeout 1m -failfast ./util/...
time go test -tags privileged -timeout 1m -failfast ./version/...

View File

@@ -30,7 +30,7 @@ jobs:
- 'management/**'
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -41,7 +41,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
id: cache
with:
path: |
@@ -124,7 +124,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -135,7 +135,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -158,7 +158,7 @@ jobs:
run: git --no-pager diff --exit-code
- name: Test
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
- name: Upload coverage reports to Codecov
if: matrix.arch == 'amd64'
@@ -180,7 +180,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -192,7 +192,7 @@ jobs:
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
id: cache-restore
with:
path: |
@@ -229,7 +229,7 @@ jobs:
sh -c ' \
apk update; apk add --no-cache \
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
go test -buildvcs=false -tags "devcert privileged" -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server -e /client/testutil/privileged)
'
test_relay:
@@ -251,7 +251,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -266,7 +266,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -311,7 +311,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -325,7 +325,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -368,7 +368,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -383,7 +383,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -429,7 +429,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -440,7 +440,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -534,7 +534,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -545,7 +545,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -579,10 +579,11 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags devcert -run=^$ -bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/... ./shared/management/... $(go list ./management/... ./shared/management/... | grep -v -e /management/server/http)
env:
GIT_BRANCH: ${{ github.ref_name }}
api_benchmark:
name: "Management / Benchmark (API)"
@@ -628,7 +629,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -639,7 +640,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -673,12 +674,13 @@ jobs:
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
CI=true \
GIT_BRANCH=${{ github.ref_name }} \
go test -tags=benchmark \
-run=^$ \
-bench=. \
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
-timeout 20m ./management/server/http/...
env:
GIT_BRANCH: ${{ github.ref_name }}
api_integration_test:
name: "Management / Integration"
@@ -697,7 +699,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -708,7 +710,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}

View File

@@ -23,7 +23,7 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
id: go
with:
go-version-file: "go.mod"
@@ -35,7 +35,7 @@ jobs:
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
${{ env.cache }}
@@ -68,7 +68,7 @@ jobs:
run: |
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
$cmd = "$goExe test -tags `"devcert privileged`" -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
- name: test

View File

@@ -37,7 +37,7 @@ jobs:
display_name: Linux
name: ${{ matrix.display_name }}
runs-on: ${{ matrix.os }}
timeout-minutes: 25
timeout-minutes: 15
steps:
- name: Checkout code
uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
@@ -48,7 +48,7 @@ jobs:
run: |
! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
@@ -62,4 +62,4 @@ jobs:
skip-cache: true
skip-save-cache: true
cache-invalidation-interval: 0
args: --timeout=20m
args: --timeout=12m

View File

@@ -20,7 +20,7 @@ jobs:
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
- name: Setup Android SDK
@@ -28,13 +28,13 @@ jobs:
with:
cmdline-tools-version: 8512546
- name: Setup Java
uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
uses: actions/setup-java@1bcf9fb12cf4aa7d266a90ae39939e61372fe520
with:
java-version: "11"
distribution: "adopt"
- name: NDK Cache
id: ndk-cache
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: /usr/local/lib/android/sdk/ndk
key: ndk-cache-23.1.7779620
@@ -58,7 +58,7 @@ jobs:
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
- name: install gomobile

View File

@@ -166,12 +166,12 @@ jobs:
fi
- name: Set up Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache/restore@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache/restore@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
~/go/pkg/mod
@@ -374,12 +374,12 @@ jobs:
fi
- name: Set up Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
~/go/pkg/mod
@@ -469,12 +469,12 @@ jobs:
fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false
- name: Set up Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
cache: false
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: |
~/go/pkg/mod

View File

@@ -73,12 +73,12 @@ jobs:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
- name: Cache Go modules
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
uses: actions/cache@2c8a9bd7457de244a408f35966fab2fb45fda9c8 # v6.0.0
with:
path: ~/go/pkg/mod
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}

View File

@@ -23,7 +23,7 @@ jobs:
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
- name: Install dependencies
@@ -48,7 +48,7 @@ jobs:
with:
persist-credentials: false
- name: Install Go
uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
uses: actions/setup-go@924ae3a1cded613372ab5595356fb5720e22ba16 # v6.5.0
with:
go-version-file: "go.mod"
- name: Build Wasm client

View File

@@ -1,4 +1,4 @@
.PHONY: lint lint-all lint-install setup-hooks
.PHONY: lint lint-all lint-install setup-hooks test-unit test-privileged
GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
# Install golangci-lint locally if needed
@@ -25,3 +25,15 @@ setup-hooks:
@git config core.hooksPath .githooks
@chmod +x .githooks/pre-push
@echo "✅ Git hooks configured! Pre-push will now run 'make lint'"
# Host-safe unit tests: excludes the privileged-tagged tests (root / system-mutating).
# Runs as a normal user with no sudo and leaves host networking untouched.
test-unit:
@go test -tags devcert -timeout 10m ./...
# Privileged suite: runs the `privileged`-tagged tests inside a --privileged
# --cap-add=NET_ADMIN container via the ory/dockertest harness. Requires Docker.
# Narrow the run with env vars, e.g.:
# PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
test-privileged:
@go test -tags 'devcert privileged' -timeout 30m -run TestRunPrivilegedSuiteInDocker -v ./client/testutil/privileged/...

View File

@@ -33,10 +33,15 @@
<br/>
<br/>
<strong>
🚀 <a href="https://careers.netbird.io">We are hiring! Join us at careers.netbird.io</a>
🚀 <a href="https://netbird.io/careers">We are hiring! Join us at https://netbird.io/careers</a>
</strong>
</p>
> ### 🤖 NetBird Agent Network (Beta)
> Identity-aware access control for AI agents — keyless access to LLM APIs and private
> resources over the encrypted NetBird tunnel. See [`agent-network/`](agent-network/) or
> read the docs at **[netbird.ai](https://netbird.ai)**.
**NetBird combines a configuration-free peer-to-peer private network and a centralized access control system in a single platform, making it easy to create secure private networks for your organization or home.**
**Connect.** NetBird creates a WireGuard-based overlay network that automatically connects your machines over an encrypted tunnel, leaving behind the hassle of opening ports, complex firewall rules, VPN gateways, and so forth.

39
agent-network/README.md Normal file
View File

@@ -0,0 +1,39 @@
# NetBird Agent Network
Agent Network is NetBird's access control layer for AI agents and the people who run
them. It gives every agent a real identity, tied to your identity provider (IdP), and
governs what it can reach — the LLM APIs and AI gateways it can call, and the internal
resources it can access. Traffic flows only over the encrypted NetBird tunnel, scoped by
policy, with no API keys to leak.
> **Beta.** Agent Network is open source and can be self-hosted on your own
> infrastructure.
## How it works
Agent Network is built on two existing NetBird capabilities:
- **Overlay network** — the encrypted WireGuard mesh between peers.
- **Reverse proxy** — a NetBird peer that terminates LLM requests, establishes the
caller's identity, evaluates policies/limits/guardrails, injects the upstream provider
key server-side, forwards to the API or gateway, and records usage.
LLM traffic is routed through the proxy's identity-aware pipeline, while internal
resources (databases, internal APIs, self-hosted models) are reached directly over
peer-to-peer WireGuard tunnels, governed by the same identities and access policies.
## Where the code lives
There is no separate "agent-network" service — it reuses the reverse-proxy and management
components:
- [`proxy/`](../proxy) — the NetBird reverse proxy that serves the agent network endpoint
and runs the per-request middleware pipeline.
- [`management/internals/modules/reverseproxy/`](../management/internals/modules/reverseproxy)
— the management-side control plane: providers, policies, guardrails, limits, routing,
and usage/access logs.
## Documentation
Full documentation, architecture, and quickstart:
**https://docs.netbird.io/agent-network**

View File

@@ -0,0 +1,196 @@
//go:build privileged
package cmd
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}

View File

@@ -1,16 +1,12 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"runtime"
"syscall"
"testing"
"time"
"github.com/kardianos/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -31,186 +27,6 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
const (
serviceStartTimeout = 10 * time.Second
serviceStopTimeout = 5 * time.Second
statusPollInterval = 500 * time.Millisecond
)
// waitForServiceStatus waits for service to reach expected status with timeout
func waitForServiceStatus(expectedStatus service.Status, timeout time.Duration) (bool, error) {
cfg, err := newSVCConfig()
if err != nil {
return false, err
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
return false, err
}
ctx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
defer timeoutCancel()
ticker := time.NewTicker(statusPollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return false, fmt.Errorf("timeout waiting for service status %v", expectedStatus)
case <-ticker.C:
status, err := s.Status()
if err != nil {
// Continue polling on transient errors
continue
}
if status == expectedStatus {
return true, nil
}
}
}
}
// TestServiceLifecycle tests the complete service lifecycle
func TestServiceLifecycle(t *testing.T) {
// TODO: Add support for Windows and macOS
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
t.Skipf("Skipping service lifecycle test on unsupported OS: %s", runtime.GOOS)
}
if os.Getenv("CONTAINER") == "true" {
t.Skip("Skipping service lifecycle test in container environment")
}
originalServiceName := serviceName
serviceName = "netbirdtest" + fmt.Sprintf("%d", time.Now().Unix())
defer func() {
serviceName = originalServiceName
}()
tempDir := t.TempDir()
configPath = fmt.Sprintf("%s/netbird-test-config.json", tempDir)
logLevel = "info"
daemonAddr = fmt.Sprintf("unix://%s/netbird-test.sock", tempDir)
// Ensure cleanup even if a subtest fails and Stop/Uninstall subtests don't run.
t.Cleanup(func() {
cfg, err := newSVCConfig()
if err != nil {
t.Errorf("cleanup: create service config: %v", err)
return
}
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
if err != nil {
t.Errorf("cleanup: create service: %v", err)
return
}
// If the subtests already cleaned up, there's nothing to do.
if _, err := s.Status(); err != nil {
return
}
if err := s.Stop(); err != nil {
t.Errorf("cleanup: stop service: %v", err)
}
if err := s.Uninstall(); err != nil {
t.Errorf("cleanup: uninstall service: %v", err)
}
})
ctx := context.Background()
t.Run("Install", func(t *testing.T) {
installCmd.SetContext(ctx)
err := installCmd.RunE(installCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
status, err := s.Status()
assert.NoError(t, err)
assert.NotEqual(t, service.StatusUnknown, status)
})
t.Run("Start", func(t *testing.T) {
startCmd.SetContext(ctx)
err := startCmd.RunE(startCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Restart", func(t *testing.T) {
restartCmd.SetContext(ctx)
err := restartCmd.RunE(restartCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Reconfigure", func(t *testing.T) {
originalLogLevel := logLevel
logLevel = "debug"
defer func() {
logLevel = originalLogLevel
}()
reconfigureCmd.SetContext(ctx)
err := reconfigureCmd.RunE(reconfigureCmd, []string{})
require.NoError(t, err)
running, err := waitForServiceStatus(service.StatusRunning, serviceStartTimeout)
require.NoError(t, err)
assert.True(t, running)
})
t.Run("Stop", func(t *testing.T) {
stopCmd.SetContext(ctx)
err := stopCmd.RunE(stopCmd, []string{})
require.NoError(t, err)
stopped, err := waitForServiceStatus(service.StatusStopped, serviceStopTimeout)
require.NoError(t, err)
assert.True(t, stopped)
})
t.Run("Uninstall", func(t *testing.T) {
uninstallCmd.SetContext(ctx)
err := uninstallCmd.RunE(uninstallCmd, []string{})
require.NoError(t, err)
cfg, err := newSVCConfig()
require.NoError(t, err)
ctxSvc, cancel := context.WithCancel(context.Background())
defer cancel()
s, err := newSVC(newProgram(ctxSvc, cancel), cfg)
require.NoError(t, err)
_, err = s.Status()
assert.Error(t, err)
})
}
// TestServiceEnvVars tests environment variable parsing
func TestServiceEnvVars(t *testing.T) {
tests := []struct {

View File

@@ -1,3 +1,5 @@
//go:build privileged
package iptables
import (

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build !android && privileged
package iptables

View File

@@ -1,3 +1,5 @@
//go:build privileged
package nftables
import (

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build !android && privileged
package nftables

View File

@@ -1,3 +1,5 @@
//go:build privileged
package iface
import (

View File

@@ -1,4 +1,4 @@
//go:build linux && !android
//go:build linux && !android && privileged
package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build !linux
//go:build !linux || !privileged
package wgproxy

View File

@@ -1,4 +1,4 @@
//go:build linux && !android
//go:build linux && !android && privileged
package wgproxy
@@ -26,64 +26,6 @@ func compareUDPAddr(addr1, addr2 net.Addr) bool {
return udpAddr1.IP.Equal(udpAddr2.IP) && udpAddr1.Port == udpAddr2.Port
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_UDP_IPv4 tests RedirectAs with UDP proxy using IPv4 addresses
func TestRedirectAs_UDP_IPv4(t *testing.T) {
wgPort := 51852
@@ -256,6 +198,64 @@ func testRedirectAs(t *testing.T, proxy Proxy, wgPort int, nbAddr, p2pEndpoint *
}
}
// TestRedirectAs_eBPF_IPv4 tests RedirectAs with eBPF proxy using IPv4 addresses
func TestRedirectAs_eBPF_IPv4(t *testing.T) {
wgPort := 51850
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("192.168.0.56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_eBPF_IPv6 tests RedirectAs with eBPF proxy using IPv6 addresses
func TestRedirectAs_eBPF_IPv6(t *testing.T) {
wgPort := 51851
ebpfProxy := ebpf.NewWGEBPFProxy(wgPort, 1280)
if err := ebpfProxy.Listen(); err != nil {
t.Fatalf("failed to initialize ebpf proxy: %v", err)
}
defer func() {
if err := ebpfProxy.Free(); err != nil {
t.Errorf("failed to free ebpf proxy: %v", err)
}
}()
proxy := ebpf.NewProxyWrapper(ebpfProxy)
// NetBird UDP address of the remote peer
nbAddr := &net.UDPAddr{
IP: net.ParseIP("100.108.111.177"),
Port: 38746,
}
p2pEndpoint := &net.UDPAddr{
IP: net.ParseIP("fe80::56"),
Port: 51820,
}
testRedirectAs(t, proxy, wgPort, nbAddr, p2pEndpoint)
}
// TestRedirectAs_Multiple_Switches tests switching between multiple endpoints
func TestRedirectAs_Multiple_Switches(t *testing.T) {
wgPort := 51856

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/hashicorp/go-multierror"
"github.com/mitchellh/hashstructure/v2"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
@@ -30,11 +31,13 @@ type Manager interface {
// DefaultManager uses firewall manager to handle
type DefaultManager struct {
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
mutex sync.Mutex
firewall firewall.Manager
ipsetCounter int
peerRulesPairs map[id.RuleID][]firewall.Rule
routeRules map[id.RuleID]struct{}
previousConfigHash uint64
hasAppliedConfig bool
mutex sync.Mutex
}
func NewDefaultManager(fm firewall.Manager) *DefaultManager {
@@ -57,6 +60,23 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
return
}
// Skip the full rebuild + flush when the inputs that drive the firewall
// state are byte-for-byte identical to the last successfully applied
// update. Management re-sends the same network map far more often than it
// actually changes (account-wide updates, peer meta churn), and rebuilding
// every peer/route ACL and flushing the firewall on every such sync is the
// dominant client-side cost when nothing changed. Mirrors the same guard the
// DNS server already uses (previousConfigHash). Only the fields ApplyFiltering
// consumes participate in the hash, so an unrelated map change cannot mask a
// real ACL change.
hash, err := d.firewallConfigHash(networkMap, dnsRouteFeatureFlag)
if err != nil {
log.Errorf("unable to hash firewall configuration, applying unconditionally: %v", err)
} else if d.hasAppliedConfig && d.previousConfigHash == hash {
log.Debugf("not applying the firewall configuration update as there is nothing new (hash: %d)", hash)
return
}
start := time.Now()
defer func() {
total := 0
@@ -70,13 +90,49 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
d.applyPeerACLs(networkMap)
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
log.Errorf("Failed to apply route ACLs: %v", err)
routeErr := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag)
if routeErr != nil {
log.Errorf("Failed to apply route ACLs: %v", routeErr)
}
if err := d.firewall.Flush(); err != nil {
log.Error("failed to flush firewall rules: ", err)
flushErr := d.firewall.Flush()
if flushErr != nil {
log.Error("failed to flush firewall rules: ", flushErr)
}
// Only remember the hash once the firewall actually reflects this config.
// If applying or flushing failed, leave the previous hash untouched so the
// next (possibly identical) update is not skipped and gets a chance to
// reconcile the firewall state.
if err == nil && routeErr == nil && flushErr == nil {
d.previousConfigHash = hash
d.hasAppliedConfig = true
} else {
d.hasAppliedConfig = false
}
}
// firewallConfigHash hashes exactly the inputs ApplyFiltering uses to build the
// firewall state, so an identical hash means an identical resulting ruleset.
func (d *DefaultManager) firewallConfigHash(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) (uint64, error) {
return hashstructure.Hash(struct {
PeerRules []*mgmProto.FirewallRule
PeerRulesIsEmpty bool
RouteRules []*mgmProto.RouteFirewallRule
RouteRulesIsEmpty bool
DNSRouteFeatureFlag bool
}{
PeerRules: networkMap.GetFirewallRules(),
PeerRulesIsEmpty: networkMap.GetFirewallRulesIsEmpty(),
RouteRules: networkMap.GetRoutesFirewallRules(),
RouteRulesIsEmpty: networkMap.GetRoutesFirewallRulesIsEmpty(),
DNSRouteFeatureFlag: dnsRouteFeatureFlag,
}, hashstructure.FormatV2, &hashstructure.HashOptions{
ZeroNil: true,
IgnoreZeroValue: true,
SlicesAsSets: true,
UseStringer: true,
})
}
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {

View File

@@ -1,6 +1,7 @@
package acl
import (
"fmt"
"net/netip"
"testing"
@@ -485,3 +486,149 @@ func TestPortInfoEmpty(t *testing.T) {
})
}
}
// TestApplyFilteringSkipsUnchangedConfig verifies that an identical network map
// re-applied is recognized as a no-op (hash unchanged), while a real change to
// any firewall-relevant input forces a re-apply (hash changes). This is the
// guard that prevents a full ruleset rebuild + flush on every redundant sync.
func TestApplyFilteringSkipsUnchangedConfig(t *testing.T) {
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
t.Setenv(firewall.EnvForceUserspaceFirewall, "true")
ctrl := gomock.NewController(t)
defer ctrl.Finish()
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
ifaceMock.EXPECT().SetFilter(gomock.Any())
network := netip.MustParsePrefix("172.0.0.1/32")
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: network.Addr(),
Network: network,
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, fw.Close(nil))
}()
acl := NewDefaultManager(fw)
networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "22",
},
},
FirewallRulesIsEmpty: false,
}
acl.ApplyFiltering(networkMap, false)
require.True(t, acl.hasAppliedConfig, "config should be marked applied after first apply")
firstHash := acl.previousConfigHash
require.NotZero(t, firstHash)
// Re-applying the identical map must not change the recorded hash: the
// expensive rebuild path was skipped.
acl.ApplyFiltering(networkMap, false)
assert.Equal(t, firstHash, acl.previousConfigHash,
"identical re-apply must be a no-op (hash unchanged)")
// A real change must produce a different hash and re-apply.
networkMap.FirewallRules[0].Action = mgmProto.RuleAction_DROP
acl.ApplyFiltering(networkMap, false)
assert.NotEqual(t, firstHash, acl.previousConfigHash,
"changing a rule's action must force a re-apply (hash changed)")
// The dnsRouteFeatureFlag also participates in the hash.
changedHash := acl.previousConfigHash
acl.ApplyFiltering(networkMap, true)
assert.NotEqual(t, changedHash, acl.previousConfigHash,
"flipping dnsRouteFeatureFlag must force a re-apply (hash changed)")
}
func buildNetworkMap(peerRules, routeRules int) *mgmProto.NetworkMap {
nm := &mgmProto.NetworkMap{
FirewallRulesIsEmpty: peerRules == 0,
RoutesFirewallRulesIsEmpty: routeRules == 0,
}
for i := range peerRules {
nm.FirewallRules = append(nm.FirewallRules, &mgmProto.FirewallRule{
PeerIP: fmt.Sprintf("10.%d.%d.%d", i>>16&0xff, i>>8&0xff, i&0xff),
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: fmt.Sprintf("%d", 1024+i%64511),
})
}
for i := range routeRules {
nm.RoutesFirewallRules = append(nm.RoutesFirewallRules, &mgmProto.RouteFirewallRule{
Destination: fmt.Sprintf("192.168.%d.0/24", i%256),
SourceRanges: []string{fmt.Sprintf("10.0.%d.0/24", i%256)},
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_ALL,
})
}
return nm
}
func BenchmarkFirewallConfigHash_Small(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(10, 5)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Medium(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(100, 50)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
func BenchmarkFirewallConfigHash_Large(b *testing.B) {
d := &DefaultManager{}
nm := buildNetworkMap(1000, 200)
b.ResetTimer()
for b.Loop() {
_, _ = d.firewallConfigHash(nm, false)
}
}
// TestFirewallConfigHashDeterministic verifies the hash is stable for equal
// inputs and order-independent for the rule slices (management does not
// guarantee rule order).
func TestFirewallConfigHashDeterministic(t *testing.T) {
d := &DefaultManager{}
nm1 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
{PeerIP: "10.0.0.1", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_ACCEPT, Protocol: mgmProto.RuleProtocol_TCP, Port: "22"},
{PeerIP: "10.0.0.2", Direction: mgmProto.RuleDirection_IN, Action: mgmProto.RuleAction_DROP, Protocol: mgmProto.RuleProtocol_TCP, Port: "80"},
},
}
// Same rules, reversed order.
nm2 := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{
nm1.FirewallRules[1],
nm1.FirewallRules[0],
},
}
h1, err := d.firewallConfigHash(nm1, false)
require.NoError(t, err)
h2, err := d.firewallConfigHash(nm2, false)
require.NoError(t, err)
assert.Equal(t, h1, h2, "hash must be order-independent for rule slices")
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"net"
"net/netip"
"slices"
"strings"
"github.com/miekg/dns"
@@ -167,7 +168,10 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
case dns.TypeA:
alternativeNetwork = "ip6"
default:
return dns.RcodeNameError
// Non-address types reach LookupIP only unexpectedly; without an
// address pair to probe we cannot prove the name is absent, so answer
// NODATA rather than a poisoning NXDOMAIN.
return dns.RcodeSuccess
}
if _, err := r.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
@@ -184,6 +188,230 @@ func getRcodeForNotFound(ctx context.Context, r resolver, domain string, origina
return dns.RcodeSuccess
}
// RecordResolver is the host resolver surface used to forward non-address
// record queries. net.DefaultResolver satisfies it.
type RecordResolver interface {
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
// LookupRecords resolves a non-address DNS record type through the host
// resolver and returns the resource records and the DNS rcode. Types the host
// resolver cannot answer (anything not covered by the net.Resolver Lookup*
// methods) yield NODATA so that a routed name is never poisoned with NXDOMAIN
// for an unsupported type.
func LookupRecords(ctx context.Context, r RecordResolver, name string, qtype uint16, ttl uint32) ([]dns.RR, int) {
fqdn := dns.Fqdn(name)
switch qtype {
case dns.TypeMX:
return lookupMX(ctx, r, name, fqdn, ttl)
case dns.TypeTXT:
return lookupTXT(ctx, r, name, fqdn, ttl)
case dns.TypeNS:
return lookupNS(ctx, r, name, fqdn, ttl)
case dns.TypeSRV:
return lookupSRV(ctx, r, name, fqdn, ttl)
case dns.TypeCNAME:
return lookupCNAME(ctx, r, name, fqdn, ttl)
case dns.TypePTR:
return lookupPTR(ctx, r, name, fqdn, ttl)
default:
return nil, dns.RcodeSuccess
}
}
func recordHeader(fqdn string, rrtype uint16, ttl uint32) dns.RR_Header {
return dns.RR_Header{Name: fqdn, Rrtype: rrtype, Class: dns.ClassINET, Ttl: ttl}
}
func lookupMX(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupMX(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, mx := range recs {
rrs = append(rrs, &dns.MX{
Hdr: recordHeader(fqdn, dns.TypeMX, ttl),
Preference: mx.Pref,
Mx: dns.Fqdn(mx.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupTXT(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupTXT(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, txt := range recs {
rrs = append(rrs, &dns.TXT{
Hdr: recordHeader(fqdn, dns.TypeTXT, ttl),
Txt: chunkTXT(txt),
})
}
return rrs, dns.RcodeSuccess
}
func lookupNS(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
recs, err := r.LookupNS(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, ns := range recs {
rrs = append(rrs, &dns.NS{
Hdr: recordHeader(fqdn, dns.TypeNS, ttl),
Ns: dns.Fqdn(ns.Host),
})
}
return rrs, dns.RcodeSuccess
}
func lookupSRV(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
_, recs, err := r.LookupSRV(ctx, "", "", name)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(recs))
for _, srv := range recs {
rrs = append(rrs, &dns.SRV{
Hdr: recordHeader(fqdn, dns.TypeSRV, ttl),
Priority: srv.Priority,
Weight: srv.Weight,
Port: srv.Port,
Target: dns.Fqdn(srv.Target),
})
}
return rrs, dns.RcodeSuccess
}
func lookupCNAME(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
cname, err := r.LookupCNAME(ctx, name)
if err != nil {
return nil, rcodeForRecordError(err)
}
// LookupCNAME returns the queried name itself when the name resolves but
// has no CNAME record; that is a NODATA result, not a CNAME.
if strings.EqualFold(dns.Fqdn(cname), fqdn) {
return nil, dns.RcodeSuccess
}
return []dns.RR{&dns.CNAME{
Hdr: recordHeader(fqdn, dns.TypeCNAME, ttl),
Target: dns.Fqdn(cname),
}}, dns.RcodeSuccess
}
func lookupPTR(ctx context.Context, r RecordResolver, name, fqdn string, ttl uint32) ([]dns.RR, int) {
addr, ok := ptrQueryAddr(name)
if !ok {
return nil, dns.RcodeSuccess
}
names, err := r.LookupAddr(ctx, addr)
if err != nil {
return nil, rcodeForRecordError(err)
}
rrs := make([]dns.RR, 0, len(names))
for _, n := range names {
rrs = append(rrs, &dns.PTR{
Hdr: recordHeader(fqdn, dns.TypePTR, ttl),
Ptr: dns.Fqdn(n),
})
}
return rrs, dns.RcodeSuccess
}
// ptrQueryAddr converts a reverse-DNS query name (in-addr.arpa or ip6.arpa)
// into the address string expected by net.Resolver.LookupAddr. It reports false
// when the name is not a well-formed reverse name.
func ptrQueryAddr(qname string) (string, bool) {
name := strings.TrimSuffix(strings.ToLower(dns.Fqdn(qname)), ".")
switch {
case strings.HasSuffix(name, ".in-addr.arpa"):
return parseInAddrArpa(strings.TrimSuffix(name, ".in-addr.arpa"))
case strings.HasSuffix(name, ".ip6.arpa"):
return parseIP6Arpa(strings.TrimSuffix(name, ".ip6.arpa"))
default:
return "", false
}
}
// parseInAddrArpa turns the label portion of an in-addr.arpa name into an IPv4
// address string, reporting false when it is not a well-formed reverse name.
func parseInAddrArpa(labelPart string) (string, bool) {
labels := strings.Split(labelPart, ".")
if len(labels) != 4 {
return "", false
}
slices.Reverse(labels)
addr, err := netip.ParseAddr(strings.Join(labels, "."))
if err != nil || !addr.Is4() {
return "", false
}
return addr.String(), true
}
// parseIP6Arpa turns the nibble portion of an ip6.arpa name into an IPv6
// address string, reporting false when it is not a well-formed reverse name.
func parseIP6Arpa(nibblePart string) (string, bool) {
nibbles := strings.Split(nibblePart, ".")
if len(nibbles) != 32 {
return "", false
}
slices.Reverse(nibbles)
var sb strings.Builder
for i, n := range nibbles {
if i > 0 && i%4 == 0 {
sb.WriteByte(':')
}
sb.WriteString(n)
}
addr, err := netip.ParseAddr(sb.String())
if err != nil || !addr.Is6() {
return "", false
}
return addr.String(), true
}
// rcodeForRecordError maps a non-address lookup error to a DNS rcode. A
// not-found result becomes NODATA rather than NXDOMAIN: net.DNSError.IsNotFound
// does not distinguish a missing name from a name that exists only with records
// of other types, so the name cannot be proven absent and must not be poisoned.
func rcodeForRecordError(err error) int {
var dnsErr *net.DNSError
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
return dns.RcodeSuccess
}
return dns.RcodeServerFailure
}
// chunkTXT splits a TXT string into character-strings no longer than 255 bytes
// so the record can be packed. The chunks form one TXT resource record.
func chunkTXT(s string) []string {
const maxLen = 255
if len(s) <= maxLen {
return []string{s}
}
var chunks []string
for len(s) > maxLen {
chunks = append(chunks, s[:maxLen])
s = s[maxLen:]
}
if len(s) > 0 {
chunks = append(chunks, s)
}
return chunks
}
// FormatAnswers formats DNS resource records for logging.
func FormatAnswers(answers []dns.RR) string {
if len(answers) == 0 {

View File

@@ -5,6 +5,7 @@ import (
"errors"
"net"
"net/netip"
"strings"
"testing"
"github.com/miekg/dns"
@@ -121,6 +122,164 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
}
func TestPtrQueryAddr(t *testing.T) {
tests := []struct {
name string
qname string
want string
wantOK bool
}{
{name: "ipv4", qname: "4.3.2.1.in-addr.arpa.", want: "1.2.3.4", wantOK: true},
{name: "ipv4 no trailing dot", qname: "1.0.0.127.in-addr.arpa", want: "127.0.0.1", wantOK: true},
{
name: "ipv6",
qname: "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.",
want: "2001:db8::1",
wantOK: true,
},
{name: "ipv4 wrong label count", qname: "2.1.in-addr.arpa.", wantOK: false},
{name: "ipv6 wrong nibble count", qname: "1.0.ip6.arpa.", wantOK: false},
{name: "not a reverse name", qname: "example.com.", wantOK: false},
{name: "ipv4 bad octet", qname: "4.3.2.999.in-addr.arpa.", wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := ptrQueryAddr(tt.qname)
assert.Equal(t, tt.wantOK, ok, "parse success mismatch")
if tt.wantOK {
assert.Equal(t, tt.want, got, "parsed address mismatch")
}
})
}
}
type mockRecordResolver struct {
mx []*net.MX
txt []string
ns []*net.NS
srv []*net.SRV
cname string
ptr []string
err error
}
func (m *mockRecordResolver) LookupMX(context.Context, string) ([]*net.MX, error) {
return m.mx, m.err
}
func (m *mockRecordResolver) LookupTXT(context.Context, string) ([]string, error) {
return m.txt, m.err
}
func (m *mockRecordResolver) LookupNS(context.Context, string) ([]*net.NS, error) {
return m.ns, m.err
}
func (m *mockRecordResolver) LookupSRV(context.Context, string, string, string) (string, []*net.SRV, error) {
return "", m.srv, m.err
}
func (m *mockRecordResolver) LookupCNAME(context.Context, string) (string, error) {
return m.cname, m.err
}
func (m *mockRecordResolver) LookupAddr(context.Context, string) ([]string, error) {
return m.ptr, m.err
}
func TestLookupRecords(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com."}
t.Run("MX success", func(t *testing.T) {
r := &mockRecordResolver{mx: []*net.MX{{Host: "mail.example.com.", Pref: 10}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "mail.example.com.", rrs[0].(*dns.MX).Mx)
})
t.Run("TXT short string is one character-string", func(t *testing.T) {
r := &mockRecordResolver{txt: []string{"v=spf1 -all"}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, []string{"v=spf1 -all"}, rrs[0].(*dns.TXT).Txt)
})
t.Run("TXT chunks long strings", func(t *testing.T) {
long := strings.Repeat("a", 300)
r := &mockRecordResolver{txt: []string{long}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeTXT, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
txt := rrs[0].(*dns.TXT).Txt
require.Len(t, txt, 2, "300-byte string should split into two character-strings")
assert.Equal(t, 255, len(txt[0]))
assert.Equal(t, 45, len(txt[1]))
})
t.Run("NS success", func(t *testing.T) {
r := &mockRecordResolver{ns: []*net.NS{{Host: "ns1.example.com."}}}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeNS, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "ns1.example.com.", rrs[0].(*dns.NS).Ns)
})
t.Run("SRV success", func(t *testing.T) {
r := &mockRecordResolver{srv: []*net.SRV{{Target: "sip.example.com.", Port: 5060}}}
rrs, rcode := LookupRecords(context.Background(), r, "_sip._tcp.example.com.", dns.TypeSRV, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, uint16(5060), rrs[0].(*dns.SRV).Port)
})
t.Run("CNAME success", func(t *testing.T) {
r := &mockRecordResolver{cname: "target.example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "www.example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "target.example.com.", rrs[0].(*dns.CNAME).Target)
})
t.Run("CNAME equal to name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{cname: "example.com."}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCNAME, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs, "self-referential CNAME is NODATA")
})
t.Run("PTR success", func(t *testing.T) {
r := &mockRecordResolver{ptr: []string{"host.example.com."}}
rrs, rcode := LookupRecords(context.Background(), r, "4.3.2.1.in-addr.arpa.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
require.Len(t, rrs, 1)
assert.Equal(t, "host.example.com.", rrs[0].(*dns.PTR).Ptr)
})
t.Run("PTR malformed name is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypePTR, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
t.Run("not found is NODATA never NXDOMAIN", func(t *testing.T) {
r := &mockRecordResolver{err: notFound}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeSuccess, rcode, "missing record must not poison the name")
})
t.Run("server failure maps to SERVFAIL", func(t *testing.T) {
r := &mockRecordResolver{err: &net.DNSError{Err: "server misbehaving", IsTemporary: true}}
_, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeMX, 300)
assert.Equal(t, dns.RcodeServerFailure, rcode)
})
t.Run("unsupported type is NODATA", func(t *testing.T) {
r := &mockRecordResolver{}
rrs, rcode := LookupRecords(context.Background(), r, "example.com.", dns.TypeCAA, 300)
assert.Equal(t, dns.RcodeSuccess, rcode)
assert.Empty(t, rrs)
})
}
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{

View File

@@ -0,0 +1,485 @@
//go:build privileged
package dns
import (
"context"
"fmt"
"net/netip"
"os"
"testing"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbdns "github.com/netbirdio/netbird/dns"
)
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}

View File

@@ -10,7 +10,6 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -23,7 +22,6 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns/local"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -104,466 +102,6 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger())
}
func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
testCases := []struct {
name string
initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone
initSerial uint64
inputSerial uint64
inputUpdate nbdns.Config
shouldFail bool
expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question
}{
{
name: "Initial Config Should Succeed",
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
{
domain: nbdns.RootZone,
priority: PriorityDefault,
},
},
expectedLocalQs: []dns.Question{{Name: "peera.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}},
},
{
name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: "netbird.cloud",
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
},
},
expectedUpstreamMap: []handlerWrapper{
{
domain: "netbird.io",
priority: PriorityUpstream,
},
{
domain: "netbird.cloud",
priority: PriorityLocal,
},
},
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
},
{
name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 2,
inputSerial: 1,
shouldFail: true,
},
{
name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
},
},
},
shouldFail: true,
},
{
name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: nil,
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: nameServers,
Primary: true,
},
},
},
expectedUpstreamMap: []handlerWrapper{{
domain: ".",
priority: PriorityDefault,
}},
},
{
name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
{
name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &mockHandler{},
priority: PriorityUpstream,
},
},
initSerial: 0,
inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{},
},
}
for n, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
privKey, _ := wgtypes.GenerateKey()
newNet, err := stdnet.NewNet(context.Background(), nil)
if err != nil {
t.Fatal(err)
}
opts := iface.WGIFaceOpts{
IFaceName: fmt.Sprintf("utun230%d", n),
Address: wgaddr.MustParseWGAddress(fmt.Sprintf("100.66.100.%d/32", n+1)),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Fatal(err)
}
err = wgIface.Create()
if err != nil {
t.Fatal(err)
}
defer func() {
err = wgIface.Close()
if err != nil {
t.Log(err)
}
}()
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Fatal(err)
}
err = dnsServer.Initialize()
if err != nil {
t.Fatal(err)
}
defer func() {
err = dnsServer.hostManager.restoreHostDNS()
if err != nil {
t.Log(err)
}
}()
dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial
err = dnsServer.UpdateDNSServer(testCase.inputSerial, testCase.inputUpdate)
if err != nil {
if testCase.shouldFail {
return
}
t.Fatalf("update dns server should not fail, got error: %v", err)
}
if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
}
for _, expected := range testCase.expectedUpstreamMap {
found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found {
t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
}
}
var responseMSG *dns.Msg
responseWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
responseMSG = m
return nil
},
}
for _, q := range testCase.expectedLocalQs {
dnsServer.localResolver.ServeDNS(responseWriter, &dns.Msg{
Question: []dns.Question{q},
})
}
if len(testCase.expectedLocalQs) > 0 {
assert.NotNil(t, responseMSG, "response message should not be nil")
assert.Equal(t, dns.RcodeSuccess, responseMSG.Rcode, "response code should be success")
assert.NotEmpty(t, responseMSG.Answer, "response message should have answers")
}
})
}
}
func TestDNSFakeResolverHandleUpdates(t *testing.T) {
ov := os.Getenv("NB_WG_KERNEL_DISABLED")
defer t.Setenv("NB_WG_KERNEL_DISABLED", ov)
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
newNet, err := stdnet.NewNet(context.Background(), []string{"utun2301"})
if err != nil {
t.Errorf("create stdnet: %v", err)
return
}
privKey, _ := wgtypes.GeneratePrivateKey()
opts := iface.WGIFaceOpts{
IFaceName: "utun2301",
Address: wgaddr.MustParseWGAddress("100.66.100.1/32"),
WGPort: 33100,
WGPrivKey: privKey.String(),
MTU: iface.DefaultMTU,
TransportNet: newNet,
}
wgIface, err := iface.NewWGIFace(opts)
if err != nil {
t.Errorf("build interface wireguard: %v", err)
return
}
err = wgIface.Create()
if err != nil {
t.Errorf("create and init wireguard interface: %v", err)
return
}
defer func() {
if err = wgIface.Close(); err != nil {
t.Logf("close wireguard interface: %v", err)
}
}()
ctrl := gomock.NewController(t)
defer ctrl.Finish()
packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().SetTCPPacketHook(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
if err := wgIface.SetFilter(packetfilter); err != nil {
t.Errorf("set packet filter: %v", err)
return
}
dnsServer, err := NewDefaultServer(context.Background(), DefaultServerConfig{
WgInterface: wgIface,
CustomAddress: "",
StatusRecorder: peer.NewRecorder("mgm"),
StateManager: nil,
DisableSys: false,
})
if err != nil {
t.Errorf("create DNS server: %v", err)
return
}
err = dnsServer.Initialize()
if err != nil {
t.Errorf("run DNS server: %v", err)
return
}
defer func() {
if err = dnsServer.hostManager.restoreHostDNS(); err != nil {
t.Logf("restore DNS settings on the host: %v", err)
return
}
}()
dnsServer.dnsMuxHandlers = []handlerWrapper{
{
domain: zoneRecords[0].Name,
handler: &local.Resolver{},
priority: PriorityUpstream,
},
}
dnsServer.localResolver.Update([]nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}})
dnsServer.updateSerial = 0
nameServers := []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
{
IP: netip.MustParseAddr("8.8.4.4"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
}
update := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "netbird.cloud",
Records: zoneRecords,
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"netbird.io"},
NameServers: nameServers,
},
{
NameServers: nameServers,
Primary: true,
},
},
}
// Start the server with regular configuration
if err := dnsServer.UpdateDNSServer(1, update); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update2 := update
update2.ServiceEnable = false
// Disable the server, stop the listener
if err := dnsServer.UpdateDNSServer(2, update2); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
update3 := update2
update3.NameServerGroups = update3.NameServerGroups[:1]
// But service still get updates and we checking that we handle
// internal state in the right way
if err := dnsServer.UpdateDNSServer(3, update3); err != nil {
t.Fatalf("update dns server should not fail, got error: %v", err)
return
}
}
func TestDNSServerStartStop(t *testing.T) {
testCases := []struct {
name string

View File

@@ -37,6 +37,12 @@ const (
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
LookupMX(ctx context.Context, name string) ([]*net.MX, error)
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupNS(ctx context.Context, name string) ([]*net.NS, error)
LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error)
LookupCNAME(ctx context.Context, host string) (string, error)
LookupAddr(ctx context.Context, addr string) ([]string, error)
}
type firewaller interface {
@@ -210,12 +216,6 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
resp := query.SetReply(query)
network := resutil.NetworkForQtype(question.Qtype)
if network == "" {
resp.Rcode = dns.RcodeNotImplemented
f.writeResponse(logger, w, resp, qname, startTime)
return
}
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
if mostSpecificResId == "" {
@@ -227,9 +227,46 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel()
reqHasEdns := query.IsEdns0() != nil
switch question.Qtype {
case dns.TypeA, dns.TypeAAAA:
f.handleAddressQuery(ctx, logger, w, resp, mostSpecificResId, matchingEntries, reqHasEdns, startTime)
case dns.TypeMX, dns.TypeTXT, dns.TypeNS, dns.TypeSRV, dns.TypeCNAME, dns.TypePTR:
f.handleRecordQuery(ctx, logger, w, resp, startTime)
default:
// The domain is routed here, so any other type is answered NODATA
// (NOERROR, empty answer) rather than falling back to a resolver that
// would poison the name with NXDOMAIN. The Extended DNS Error lets a
// client tell this capability-driven NODATA apart from an
// authoritative one. The OPT pseudo-record must not appear unless the
// query advertised EDNS0.
if reqHasEdns {
attachEDE(resp, dns.ExtendedErrorCodeNotSupported, "netbird forwarder: unsupported query type")
}
f.writeResponse(logger, w, resp, qname, startTime)
}
}
// handleAddressQuery resolves A/AAAA queries, programs the firewall sets and
// resolved-IP state, and caches the answer for resilience on upstream failure.
func (f *DNSForwarder) handleAddressQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
mostSpecificResId route.ResID,
matchingEntries []*ForwarderEntry,
reqHasEdns bool,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
network := resutil.NetworkForQtype(question.Qtype)
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
f.handleDNSError(ctx, logger, w, question, resp, qname, result, reqHasEdns, startTime)
return
}
@@ -240,6 +277,25 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
f.writeResponse(logger, w, resp, qname, startTime)
}
// handleRecordQuery resolves non-address record types (MX, TXT, NS, SRV,
// CNAME, PTR) through the host resolver. Missing records are answered NODATA so
// the routed name is never poisoned with NXDOMAIN.
func (f *DNSForwarder) handleRecordQuery(
ctx context.Context,
logger *log.Entry,
w dns.ResponseWriter,
resp *dns.Msg,
startTime time.Time,
) {
question := resp.Question[0]
qname := strings.ToLower(question.Name)
records, rcode := resutil.LookupRecords(ctx, f.resolver, qname, question.Qtype, f.ttl)
resp.Rcode = rcode
resp.Answer = append(resp.Answer, records...)
f.writeResponse(logger, w, resp, qname, startTime)
}
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed to write DNS response: %v", err)

View File

@@ -133,6 +133,41 @@ func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([
return args.Get(0).([]netip.Addr), args.Error(1)
}
func (m *MockResolver) LookupMX(ctx context.Context, name string) ([]*net.MX, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.MX)
return recs, args.Error(1)
}
func (m *MockResolver) LookupTXT(ctx context.Context, name string) ([]string, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func (m *MockResolver) LookupNS(ctx context.Context, name string) ([]*net.NS, error) {
args := m.Called(ctx, name)
recs, _ := args.Get(0).([]*net.NS)
return recs, args.Error(1)
}
func (m *MockResolver) LookupSRV(ctx context.Context, service, proto, name string) (string, []*net.SRV, error) {
args := m.Called(ctx, service, proto, name)
recs, _ := args.Get(1).([]*net.SRV)
return args.String(0), recs, args.Error(2)
}
func (m *MockResolver) LookupCNAME(ctx context.Context, host string) (string, error) {
args := m.Called(ctx, host)
return args.String(0), args.Error(1)
}
func (m *MockResolver) LookupAddr(ctx context.Context, addr string) ([]string, error) {
args := m.Called(ctx, addr)
recs, _ := args.Get(0).([]string)
return recs, args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
@@ -545,12 +580,15 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
// A type with no net.Resolver Lookup method (CAA) must answer NODATA
// (NOERROR, empty) rather than NXDOMAIN/NOTIMP to avoid poisoning the name.
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
expectEDE bool
description string
}{
{
@@ -562,28 +600,13 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
name: "unsupported query type returns NODATA",
queryType: dns.TypeCAA,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
expectedCode: dns.RcodeSuccess,
expectEDE: true,
description: "Unsupported types answer NODATA, not NXDOMAIN/NOTIMP",
},
}
@@ -599,6 +622,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
query.SetEdns0(dns.DefaultMsgSize, false)
// Capture the written response
var writtenResp *dns.Msg
@@ -614,10 +638,213 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
assert.Empty(t, writtenResp.Answer, "Non-address response should carry no answers")
if tt.expectEDE {
require.NotNil(t, writtenResp.IsEdns0(), "EDNS0 client should get an OPT in the reply")
assert.True(t, hasEDE(writtenResp, dns.ExtendedErrorCodeNotSupported),
"unsupported type NODATA should carry EDE Not Supported")
}
})
}
}
func hasEDE(m *dns.Msg, code uint16) bool {
opt := m.IsEdns0()
if opt == nil {
return false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok && ede.InfoCode == code {
return true
}
}
return false
}
func TestDNSForwarder_RecordQueries(t *testing.T) {
notFound := &net.DNSError{IsNotFound: true, Name: "example.com"}
t.Run("MX records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return([]*net.MX{{Host: "mail.example.com.", Pref: 10}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
mx, ok := resp.Answer[0].(*dns.MX)
require.True(t, ok, "answer should be an MX record")
assert.Equal(t, uint16(10), mx.Preference)
assert.Equal(t, "mail.example.com.", mx.Mx)
mockResolver.AssertExpectations(t)
})
t.Run("missing MX is NODATA not NXDOMAIN", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// A not-found cannot prove the name is absent (it may exist with only
// other record types), so it must answer NODATA, never NXDOMAIN.
mockResolver.On("LookupMX", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeMX)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode, "missing record must be NODATA")
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("NS records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return([]*net.NS{{Host: "ns1.example.com."}}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ns, ok := resp.Answer[0].(*dns.NS)
require.True(t, ok, "answer should be an NS record")
assert.Equal(t, "ns1.example.com.", ns.Ns)
mockResolver.AssertExpectations(t)
})
t.Run("missing NS is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupNS", mock.Anything, "example.com.").
Return(nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeNS)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("SRV records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", []*net.SRV{{Target: "sip.example.com.", Port: 5060, Priority: 10, Weight: 5}}, nil).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
srv, ok := resp.Answer[0].(*dns.SRV)
require.True(t, ok, "answer should be an SRV record")
assert.Equal(t, "sip.example.com.", srv.Target)
assert.Equal(t, uint16(5060), srv.Port)
assert.Equal(t, uint16(10), srv.Priority)
mockResolver.AssertExpectations(t)
})
t.Run("missing SRV is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "_sip._tcp.example.com")
mockResolver.On("LookupSRV", mock.Anything, "", "", "_sip._tcp.example.com.").
Return("", nil, notFound).Once()
resp := runRecordQuery(t, forwarder, "_sip._tcp.example.com", dns.TypeSRV)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer)
mockResolver.AssertExpectations(t)
})
t.Run("TXT records are forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
mockResolver.On("LookupTXT", mock.Anything, "example.com.").
Return([]string{"v=spf1 -all"}, nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeTXT)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
txt, ok := resp.Answer[0].(*dns.TXT)
require.True(t, ok, "answer should be a TXT record")
assert.Equal(t, []string{"v=spf1 -all"}, txt.Txt)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "www.example.com")
mockResolver.On("LookupCNAME", mock.Anything, "www.example.com.").
Return("target.example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "www.example.com", dns.TypeCNAME)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
cname, ok := resp.Answer[0].(*dns.CNAME)
require.True(t, ok, "answer should be a CNAME record")
assert.Equal(t, "target.example.com.", cname.Target)
mockResolver.AssertExpectations(t)
})
t.Run("CNAME equal to the name is NODATA", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "example.com")
// No CNAME exists: LookupCNAME echoes the queried name back.
mockResolver.On("LookupCNAME", mock.Anything, "example.com.").
Return("example.com.", nil).Once()
resp := runRecordQuery(t, forwarder, "example.com", dns.TypeCNAME)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
assert.Empty(t, resp.Answer, "self-referential CNAME means no CNAME record")
mockResolver.AssertExpectations(t)
})
t.Run("PTR record is forwarded", func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := newRecordTestForwarder(t, mockResolver, "*.in-addr.arpa")
// The reverse name is parsed back to the address LookupAddr expects.
mockResolver.On("LookupAddr", mock.Anything, "1.2.3.4").
Return([]string{"host.example.com."}, nil).Once()
resp := runRecordQuery(t, forwarder, "4.3.2.1.in-addr.arpa", dns.TypePTR)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 1)
ptr, ok := resp.Answer[0].(*dns.PTR)
require.True(t, ok, "answer should be a PTR record")
assert.Equal(t, "host.example.com.", ptr.Ptr)
mockResolver.AssertExpectations(t)
})
}
func newRecordTestForwarder(t *testing.T, r resolver, configured string) *DNSForwarder {
t.Helper()
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = r
d, err := domain.FromString(configured)
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
return forwarder
}
func runRecordQuery(t *testing.T, forwarder *DNSForwarder, qname string, qtype uint16) *dns.Msg {
t.Helper()
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(qname), qtype)
mockWriter := &test.MockResponseWriter{}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
resp := mockWriter.GetLastResponse()
require.NotNil(t, resp, "expected response to be written")
return resp
}
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string

View File

@@ -63,9 +63,7 @@ import (
"github.com/netbirdio/netbird/route"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
nbnetworkmap "github.com/netbirdio/netbird/shared/management/networkmap"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
types "github.com/netbirdio/netbird/shared/management/types"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -212,13 +210,6 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64
// latestComponents is the most-recent NetworkMapComponents decoded from
// a NetworkMapEnvelope (capability=3 peers only). Held alongside the
// NetworkMap that Calculate() produced from it so future incremental
// updates have a base to apply changes against. nil for legacy-format
// peers. Guarded by syncMsgMux.
latestComponents *types.NetworkMapComponents
networkMonitor *networkmonitor.NetworkMonitor
sshServer sshServer
@@ -904,6 +895,16 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
}
// phase times a sync sub-phase: it returns a function that records the elapsed
// duration when called. Starting the timer at the call site keeps inter-phase
// glue code out of the measurement.
func (e *Engine) phase(name string) func() {
start := time.Now()
return func() {
e.clientMetrics.RecordSyncPhase(e.ctx, name, time.Since(start))
}
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
@@ -919,71 +920,37 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return e.ctx.Err()
}
// Envelope sync responses carry PeerConfig at the top level; legacy
// NetworkMap syncs carry it under NetworkMap.PeerConfig.
if pc := update.GetPeerConfig(); pc != nil {
e.handleAutoUpdateVersion(pc.GetAutoUpdate())
} else if nm := update.GetNetworkMap(); nm != nil && nm.GetPeerConfig() != nil {
e.handleAutoUpdateVersion(nm.GetPeerConfig().GetAutoUpdate())
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
}
if err := e.updateNetbirdConfig(update.GetNetbirdConfig()); err != nil {
done := e.phase("netbird_config")
err := e.updateNetbirdConfig(update.GetNetbirdConfig())
done()
if err != nil {
return err
}
// Decode the network map from either the components envelope or the
// legacy proto.NetworkMap before the posture-check gating below, so the
// "is there a network map" decision covers both wire shapes.
var (
nm *mgmProto.NetworkMap
components *types.NetworkMapComponents
)
if envelope := update.GetNetworkMapEnvelope(); envelope != nil {
// Components-format peer: decode the envelope back to typed
// components, run Calculate() locally, and convert to the wire
// NetworkMap shape the rest of the engine consumes. Components are
// retained so future incremental updates can apply deltas instead
// of doing a full reconstruction.
localKey := e.config.WgPrivateKey.PublicKey().String()
dnsName := ""
if pc := update.GetPeerConfig(); pc != nil {
// PeerConfig.Fqdn = "<dns_label>.<dns_domain>" — extract the
// shared domain by stripping the peer's own label prefix. Falls
// back to empty if the FQDN doesn't have the expected shape.
dnsName = extractDNSDomainFromFQDN(pc.GetFqdn())
}
result, err := nbnetworkmap.EnvelopeToNetworkMap(e.ctx, envelope, localKey, dnsName)
if err != nil {
return fmt.Errorf("decode network map envelope: %w", err)
}
nm = result.NetworkMap
components = result.Components
} else {
nm = update.GetNetworkMap()
}
// Posture checks are bound to the network map presence:
// NetworkMap != nil, checks present -> apply the received checks
// NetworkMap != nil, checks nil -> posture checks were removed, clear them
// NetworkMap == nil -> config-only update (e.g. relay token rotation),
// leave the previously applied checks untouched
nm := update.GetNetworkMap()
if nm == nil {
return nil
}
if err := e.updateChecksIfNew(update.Checks); err != nil {
done = e.phase("checks")
err = e.updateChecksIfNew(update.Checks)
done()
if err != nil {
return err
}
// Only retain the components view when the server sent the envelope
// path. A legacy proto.NetworkMap means components == nil; writing it
// here would clobber a previously-cached snapshot, breaking the
// incremental-delta base on a future envelope sync.
if components != nil {
e.latestComponents = components
}
done = e.phase("persist")
e.persistSyncResponse(update)
done()
// only apply new changes and ignore old ones
if err := e.updateNetworkMap(nm); err != nil {
@@ -995,19 +962,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil
}
// extractDNSDomainFromFQDN returns the trailing dotted domain part of the
// receiving peer's FQDN — the same value the management server fills as
// dnsName when it builds the legacy NetworkMap. "peer42.netbird.cloud" →
// "netbird.cloud". An empty string is returned for unrecognized formats.
func extractDNSDomainFromFQDN(fqdn string) string {
for i := 0; i < len(fqdn); i++ {
if fqdn[i] == '.' && i+1 < len(fqdn) {
return fqdn[i+1:]
}
}
return ""
}
// updateNetbirdConfig applies the management-provided NetBird configuration:
// STUN/TURN and relay servers, flow logging and DNS settings. A nil config is a no-op,
// which is the case for sync updates carrying only a network map.
@@ -1130,7 +1084,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
}
e.checks = checks
info, err := system.GetInfoWithChecks(e.ctx, checks)
info, err := system.GetInfoWithChecks(e.ctx, checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
@@ -1161,6 +1115,20 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
return nil
}
// overlayAddresses returns our own WireGuard overlay address (v4 and v6) so it
// can be excluded from the reported network addresses; the interface coming and
// going otherwise churns the peer meta on the management server.
func (e *Engine) overlayAddresses() []netip.Addr {
var ips []netip.Addr
if e.config.WgAddr.IP.IsValid() {
ips = append(ips, e.config.WgAddr.IP)
}
if e.config.WgAddr.HasIPv6() {
ips = append(ips, e.config.WgAddr.IPv6)
}
return ips
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface == nil {
return errors.New("wireguard interface is not initialized")
@@ -1304,7 +1272,7 @@ func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
info, err := system.GetInfoWithChecks(e.ctx, e.checks, e.overlayAddresses()...)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
info = system.GetInfo(e.ctx)
@@ -1421,13 +1389,16 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address())
done := e.phase("dns_server")
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
log.Errorf("failed to update dns server, err: %v", err)
}
done()
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
// apply routes first, route related actions might depend on routing being enabled
done = e.phase("routes_classify")
routes := toRoutes(networkMap.GetRoutes())
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
@@ -1436,29 +1407,60 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.connMgr.UpdateRouteHAMap(clientRoutes)
log.Debugf("updated lazy connection manager with %d HA groups", len(clientRoutes))
}
done()
done = e.phase("routes_apply")
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
if err := e.routeManager.UpdateRoutes(serial, serverRoutes, clientRoutes, dnsRouteFeatureFlag); err != nil {
log.Errorf("failed to update routes: %v", err)
}
done()
done = e.phase("filtering")
if e.acl != nil {
e.acl.ApplyFiltering(networkMap, dnsRouteFeatureFlag)
}
done()
done = e.phase("dns_forwarder")
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
done()
// Ingress forward rules
done = e.phase("forward_rules")
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
done()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
done = e.phase("offline_peers")
e.updateOfflinePeers(networkMap.GetOfflinePeers())
done()
remotePeers, err := e.reconcilePeers(networkMap)
if err != nil {
return err
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
done = e.phase("lazy_exclude")
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done()
e.networkSerial = serial
return nil
}
// reconcilePeers applies the remote peer list from the network map (removing,
// modifying and adding peers, then updating SSH config) and returns the remote
// peers with our own peer filtered out, for use by later sync steps.
func (e *Engine) reconcilePeers(networkMap *mgmProto.NetworkMap) ([]*mgmProto.RemotePeerConfig, error) {
// Filter out own peer from the remote peers list
localPubKey := e.config.WgPrivateKey.PublicKey().String()
remotePeers := make([]*mgmProto.RemotePeerConfig, 0, len(networkMap.GetRemotePeers()))
@@ -1473,42 +1475,43 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
err := e.removeAllPeers()
e.statusRecorder.FinishPeerListModifications()
if err != nil {
return err
return nil, err
}
} else {
err := e.removePeers(remotePeers)
if err != nil {
return err
}
err = e.modifyPeers(remotePeers)
if err != nil {
return err
}
err = e.addNewPeers(remotePeers)
if err != nil {
return err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
excludedLazyPeers := e.toExcludedLazyPeers(forwardingRules, remotePeers)
e.connMgr.SetExcludeList(e.ctx, excludedLazyPeers)
done := e.phase("removed_peers")
err := e.removePeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.networkSerial = serial
done = e.phase("modified_peers")
err = e.modifyPeers(remotePeers)
done()
if err != nil {
return nil, err
}
return nil
done = e.phase("added_peers")
err = e.addNewPeers(remotePeers)
done()
if err != nil {
return nil, err
}
e.statusRecorder.FinishPeerListModifications()
e.updatePeerSSHHostKeys(remotePeers)
if err := e.updateSSHClientConfig(remotePeers); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
e.updateSSHServerAuth(networkMap.GetSshAuth())
return remotePeers, nil
}
func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool {

View File

@@ -0,0 +1,565 @@
//go:build privileged
package internal
import (
"context"
"fmt"
"net"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers were started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}

View File

@@ -6,37 +6,18 @@ import (
"net"
"net/netip"
"os"
"runtime"
"strings"
"sync"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
@@ -50,18 +31,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/monotime"
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
@@ -69,25 +39,9 @@ import (
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
"github.com/netbirdio/netbird/util"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
type MockWGIface struct {
CreateFunc func() error
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
@@ -234,129 +188,6 @@ func TestMain(m *testing.M) {
os.Exit(code)
}
func TestEngine_SSH(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(
ctx, cancel,
&EngineConfig{
WgIfaceName: "utun101",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
},
MobileDependency{},
)
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
err = engine.Start(nil, nil)
require.NoError(t, err)
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config
engine := &Engine{
@@ -631,97 +462,6 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
}
}
func TestEngine_Sync(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
// feed updates to Engine via mocked Management client
updates := make(chan *mgmtProto.SyncResponse)
defer close(updates)
syncFunc := func(ctx context.Context, info *system.Info, msgHandler func(msg *mgmtProto.SyncResponse) error) error {
for msg := range updates {
err := msgHandler(msg)
if err != nil {
t.Fatal(err)
}
}
return nil
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
engine := NewEngine(ctx, cancel, &EngineConfig{
WgIfaceName: "utun103",
WgAddr: wgaddr.MustParseWGAddress("100.64.0.1/24"),
WgPrivateKey: key,
WgPort: 33100,
MTU: iface.DefaultMTU,
}, EngineServices{
SignalClient: &signal.MockClient{},
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{})
engine.ctx = ctx
engine.dnsServer = &dns.MockServer{
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
return
}
peer1 := &mgmtProto.RemotePeerConfig{
WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.10/24"},
}
peer2 := &mgmtProto.RemotePeerConfig{
WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.11/24"},
}
peer3 := &mgmtProto.RemotePeerConfig{
WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=",
AllowedIps: []string{"100.64.0.12/24"},
}
// 1st update with just 1 peer and serial larger than the current serial of the engine => apply update
updates <- &mgmtProto.SyncResponse{
NetworkMap: &mgmtProto.NetworkMap{
Serial: 10,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3},
RemotePeersIsEmpty: false,
},
}
timeout := time.After(time.Second * 2)
for {
select {
case <-timeout:
t.Fatalf("timeout while waiting for test to finish")
return
default:
}
if getPeers(engine) == 3 && engine.networkSerial == 10 {
break
}
}
}
func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
testCases := []struct {
name string
@@ -1105,104 +845,6 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
}
}
func TestEngine_MultiplePeers(t *testing.T) {
// log.SetLevel(log.DebugLevel)
ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel()
sigServer, signalAddr, err := startSignal(t)
if err != nil {
t.Fatal(err)
return
}
defer sigServer.Stop()
mgmtServer, mgmtAddr, err := startManagement(t, t.TempDir(), "../testdata/store.sql")
if err != nil {
t.Fatal(err)
return
}
defer mgmtServer.GracefulStop()
setupKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB"
mu := sync.Mutex{}
engines := []*Engine{}
numPeers := 10
wg := sync.WaitGroup{}
wg.Add(numPeers)
// create and start peers
for i := 0; i < numPeers; i++ {
j := i
go func() {
engine, err := createEngine(ctx, cancel, setupKey, j, mgmtAddr, signalAddr)
if err != nil {
wg.Done()
t.Errorf("unable to create the engine for peer %d with error %v", j, err)
return
}
engine.dnsServer = &dns.MockServer{}
mu.Lock()
defer mu.Unlock()
guid := fmt.Sprintf("{%s}", uuid.New().String())
device.CustomWindowsGUIDString = strings.ToLower(guid)
err = engine.Start(nil, nil)
if err != nil {
t.Errorf("unable to start engine for peer %d with error %v", j, err)
wg.Done()
return
}
engines = append(engines, engine)
wg.Done()
}()
}
// wait until all have been created and started
wg.Wait()
if len(engines) != numPeers {
t.Fatal("not all peers was started")
}
// check whether all the peer have expected peers connected
expectedConnected := numPeers * (numPeers - 1)
// adjust according to timeouts
timeout := 50 * time.Second
timeoutChan := time.After(timeout)
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
loop:
for {
select {
case <-timeoutChan:
t.Fatalf("waiting for expected connections timeout after %s", timeout.String())
break loop
case <-ticker.C:
totalConnected := 0
for _, engine := range engines {
totalConnected += getConnectedPeers(engine)
}
if totalConnected == expectedConnected {
log.Infof("total connected=%d", totalConnected)
break loop
}
log.Infof("total connected=%d", totalConnected)
}
}
// cleanup test
for n, peerEngine := range engines {
t.Logf("stopping peer with interface %s from multipeer test, loopIndex %d", peerEngine.wgInterface.Name(), n)
errStop := peerEngine.mgmClient.Close()
if errStop != nil {
log.Infoln("got error trying to close management clients from engine: ", errStop)
}
errStop = peerEngine.Stop()
if errStop != nil {
log.Infoln("got error trying to close testing peers engine: ", errStop)
}
}
}
func Test_ParseNATExternalIPMappings(t *testing.T) {
ifaceList, err := net.Interfaces()
if err != nil {
@@ -1526,187 +1168,6 @@ func TestCompareNetIPLists(t *testing.T) {
}
}
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
return nil, err
}
mgmtClient, err := mgmt.NewClient(ctx, mgmtAddr, key, false)
if err != nil {
return nil, err
}
signalClient, err := signal.NewClient(ctx, signalAddr, key, false)
if err != nil {
return nil, err
}
info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(setupKey, "", info, nil, nil)
if err != nil {
return nil, err
}
var ifaceName string
if runtime.GOOS == "darwin" {
ifaceName = fmt.Sprintf("utun1%d", i)
} else {
ifaceName = fmt.Sprintf("wt%d", i)
}
wgPort := 33100 + i
conf := &EngineConfig{
WgIfaceName: ifaceName,
WgAddr: wgaddr.MustParseWGAddress(resp.PeerConfig.Address),
WgPrivateKey: key,
WgPort: wgPort,
MTU: iface.DefaultMTU,
}
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
e, err := NewEngine(ctx, cancel, conf, EngineServices{
SignalClient: signalClient,
MgmClient: mgmtClient,
RelayManager: relayMgr,
StatusRecorder: peer.NewRecorder("https://mgm"),
}, MobileDependency{}), nil
e.ctx = ctx
return e, err
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
t.Helper()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Relay: &config.Relay{
Addresses: []string{"127.0.0.1:1234"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "222222222222222222",
},
Signal: &config.Host{
Proto: "http",
URI: "localhost:10000",
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
permissionsManager := permissions.NewManager(store)
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
settingsMockManager := settings.NewMockManager(ctrl)
settingsMockManager.EXPECT().
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
Return(&types.Settings{}, nil).
AnyTimes()
settingsMockManager.EXPECT().
GetExtraSettings(gomock.Any(), gomock.Any()).
Return(&types.ExtraSettings{}, nil).
AnyTimes()
groupsManager := groups.NewManagerMock()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(updateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mgmtProto.RegisterManagementServiceServer(s, mgmtServer)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
// getConnectedPeers returns a connection Status or nil if peer connection wasn't found
func getConnectedPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
i := 0
for _, id := range e.peerStore.PeersPubKey() {
conn, _ := e.peerStore.PeerConn(id)
if conn.IsConnected() {
i++
}
}
return i
}
func getPeers(e *Engine) int {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)

View File

@@ -119,10 +119,6 @@ func (d *BindListener) ReadPackets() {
}
d.peerCfg.Log.Debugf("removing lazy endpoint for peer %s", d.peerCfg.PublicKey)
if err := d.wgIface.RemovePeer(d.peerCfg.PublicKey); err != nil {
d.peerCfg.Log.Errorf("failed to remove endpoint: %s", err)
}
_ = d.lazyConn.Close()
d.bind.RemoveEndpoint(d.fakeIP)
d.done.Done()

View File

@@ -120,6 +120,30 @@ func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentI
m.trimLocked()
}
func (m *influxDBMetrics) RecordSyncPhase(_ context.Context, agentInfo AgentInfo, phase string, duration time.Duration) {
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,phase=%s",
agentInfo.DeploymentType.String(),
agentInfo.Version,
agentInfo.OS,
agentInfo.Arch,
agentInfo.peerID,
phase,
)
m.mu.Lock()
defer m.mu.Unlock()
m.samples = append(m.samples, influxSample{
measurement: "netbird_sync_phase",
tags: tags,
fields: map[string]float64{
"duration_seconds": duration.Seconds(),
},
timestamp: time.Now(),
})
m.trimLocked()
}
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
result := "success"
if !success {

View File

@@ -78,6 +78,25 @@ Tags:
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
### Sync Phase Timing
Measurement: `netbird_sync_phase`
Breaks down where time goes inside a single sync, so the total `netbird_sync` duration can be attributed to the sub-step that dominates.
| Field | Description |
|-------|-------------|
| `duration_seconds` | Time spent in one sub-phase of sync processing |
Tags:
- `phase`: the sub-phase — `netbird_config`, `checks`, `persist`, `dns_server`, `routes_classify`, `routes_apply`, `filtering`, `dns_forwarder`, `forward_rules`, `offline_peers`, `removed_peers`, `modified_peers`, `added_peers`, `lazy_exclude`
- `deployment_type`: "cloud" | "selfhosted" | "unknown"
- `version`: NetBird version string
- `os`: Operating system (linux, darwin, windows, android, ios, etc.)
- `arch`: CPU architecture (amd64, arm64, etc.)
**Note:** this is wall-time per phase — it includes both CPU work and time spent waiting on locks. A slow phase points to *where* the time goes, not *why*; pair it with lock-wait metrics to tell contention apart from real work.
### Login Duration
Measurement: `netbird_login`
@@ -191,4 +210,52 @@ docker compose exec influxdb influx query \
# Check ingest server health
curl http://localhost:8087/health
```
```
## Analyzing a Debug Bundle
Metrics collection is always on, so every debug bundle ships a `metrics.txt` in InfluxDB line protocol — a timestamped time series of all recorded events (sync durations, sync phases, connection stages, login). You can replay it into the local stack and graph it, without a running client.
The bundle's `metrics.txt` is a rolling window (capped at 5 days / ~20k samples, see [Buffer Limits](#buffer-limits)). For a connection incident the relevant window is short (connection setup is seconds), so a bundle captured during the issue is enough.
### 1. Start the stack
```bash
# From this directory (client/internal/metrics/infra)
INFLUXDB_ADMIN_TOKEN=admin123 INFLUXDB_ADMIN_PASSWORD=admin123 GRAFANA_ADMIN_PASSWORD=admin123 \
docker compose up -d
```
(`admin123` are throwaway local credentials — fine for offline analysis.)
### 2. Clear any previous data
So you only see this bundle:
```bash
docker exec influxdb influx delete --org netbird --bucket metrics --token admin123 \
--start 1970-01-01T00:00:00Z --stop 2100-01-01T00:00:00Z
```
### 3. Import the bundle's metrics.txt
InfluxDB is not exposed on the host, so import inside the container:
```bash
docker cp /path/to/bundle/metrics.txt influxdb:/tmp/m.txt
docker exec influxdb influx write --org netbird --bucket metrics --precision ns \
--token admin123 --file /tmp/m.txt
```
Re-importing the same file is idempotent (same measurement+tags+timestamp overwrites).
### 4. View the dashboards
Grafana on http://localhost:3001 (login `admin` / `admin123`), datasource pre-provisioned:
- **Where sync time goes:** http://localhost:3001/d/netbird-sync-phases/netbird-sync-phases-where-time-goes
- **General client metrics:** http://localhost:3001/d/netbird-influxdb-metrics
**Set the time range** to cover the bundle's timestamps (e.g. "Last 7 days" or an absolute range matching when the bundle was taken) — with the default short range the panels look empty.
Bundles are distinguishable by the `version` tag; add a tag at import time (e.g. `sed 's/^netbird_\([a-z_]*\),/netbird_\1,bundle=mycase,/' metrics.txt`) if you want to compare several side by side.

View File

@@ -0,0 +1,259 @@
{
"annotations": {
"list": []
},
"editable": true,
"fiscalYearStartMonth": 0,
"graphTooltip": 1,
"links": [],
"refresh": "",
"schemaVersion": 39,
"tags": [
"netbird",
"sync"
],
"templating": {
"list": [
{
"current": {
"text": "All",
"value": "$__all"
},
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"definition": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"includeAll": true,
"label": "version",
"multi": true,
"name": "version",
"query": "import \"influxdata/influxdb/schema\"\nschema.tagValues(bucket: \"metrics\", tag: \"version\")",
"refresh": 2,
"type": "query",
"allValue": ".*"
}
]
},
"time": {
"from": "now-2d",
"to": "now"
},
"timepicker": {},
"timezone": "",
"title": "NetBird Sync Phases (where time goes)",
"uid": "netbird-sync-phases",
"version": 1,
"panels": [
{
"id": 1,
"title": "Time per phase over time (stacked, ms)",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 10,
"w": 24,
"x": 0,
"y": 0
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "bars",
"stacking": {
"mode": "normal",
"group": "A"
},
"fillOpacity": 80,
"lineWidth": 0
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "multi",
"sort": "desc"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"phase\"])\n |> group(columns: [\"phase\"])"
}
]
},
{
"id": 2,
"title": "p95 per phase (ms)",
"type": "bargauge",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 0,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"color": {
"mode": "continuous-GrYlRd"
}
},
"overrides": []
},
"options": {
"displayMode": "gradient",
"orientation": "horizontal",
"reduceOptions": {
"calcs": [
"lastNotNull"
],
"fields": "",
"values": false
},
"showUnfilled": true
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> sort(columns: [\"_value\"], desc: true)"
}
]
},
{
"id": 3,
"title": "Per-phase stats (ms): mean / p95 / max",
"type": "table",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 11,
"w": 12,
"x": 12,
"y": 10
},
"fieldConfig": {
"defaults": {
"unit": "ms"
},
"overrides": []
},
"options": {
"showHeader": true,
"sortBy": [
{
"displayName": "max",
"desc": true
}
]
},
"transformations": [
{
"id": "merge",
"options": {}
}
],
"targets": [
{
"refId": "mean",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> mean()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"mean\"})"
},
{
"refId": "p95",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> quantile(q: 0.95)\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"p95\"})"
},
{
"refId": "max",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync_phase\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> group(columns: [\"phase\"])\n |> max()\n |> group()\n |> keep(columns: [\"phase\", \"_value\"])\n |> rename(columns: {_value: \"max\"})"
}
]
},
{
"id": 4,
"title": "Total sync duration (netbird_sync, ms) \u2014 reference",
"type": "timeseries",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"gridPos": {
"h": 8,
"w": 24,
"x": 0,
"y": 21
},
"fieldConfig": {
"defaults": {
"unit": "ms",
"custom": {
"drawStyle": "points",
"pointSize": 5
}
},
"overrides": []
},
"options": {
"legend": {
"displayMode": "table",
"placement": "right",
"calcs": [
"max",
"mean"
]
},
"tooltip": {
"mode": "single"
}
},
"targets": [
{
"refId": "A",
"datasource": {
"type": "influxdb",
"uid": "influxdb"
},
"query": "from(bucket: \"metrics\")\n |> range(start: v.timeRangeStart, stop: v.timeRangeStop)\n |> filter(fn: (r) => r._measurement == \"netbird_sync\" and r._field == \"duration_seconds\")\n |> filter(fn: (r) => r.version =~ /${version:regex}/)\n |> map(fn: (r) => ({ r with _value: r._value * 1000.0 }))\n |> keep(columns: [\"_time\", \"_value\", \"version\"])\n |> group(columns: [\"version\"])"
}
]
}
]
}

View File

@@ -19,7 +19,7 @@ const (
defaultListenAddr = ":8087"
defaultInfluxDBURL = "http://influxdb:8086/api/v2/write?org=netbird&bucket=metrics&precision=ns"
maxBodySize = 50 * 1024 * 1024 // 50 MB max request body
maxDurationSeconds = 300.0 // reject any duration field > 5 minutes
maxDurationSeconds = 86400.0 // reject any duration field > 24 hours
peerIDLength = 16 // truncated SHA-256: 8 bytes = 16 hex chars
maxTagValueLength = 64 // reject tag values longer than this
)
@@ -59,6 +59,19 @@ var allowedMeasurements = map[string]measurementSpec{
"peer_id": true,
},
},
"netbird_sync_phase": {
allowedFields: map[string]bool{
"duration_seconds": true,
},
allowedTags: map[string]bool{
"deployment_type": true,
"version": true,
"os": true,
"arch": true,
"peer_id": true,
"phase": true,
},
},
"netbird_login": {
allowedFields: map[string]bool{
"duration_seconds": true,

View File

@@ -53,14 +53,14 @@ func TestValidateLine_NegativeValue(t *testing.T) {
}
func TestValidateLine_DurationTooLarge(t *testing.T) {
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=999 1234567890`
line := `netbird_sync,deployment_type=cloud,version=1.0.0,os=linux,arch=amd64,peer_id=abc duration_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")
}
func TestValidateLine_TotalSecondsTooLarge(t *testing.T) {
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=500 1234567890`
line := `netbird_peer_connection,deployment_type=cloud,connection_type=ice,attempt_type=initial,version=1.0.0,os=linux,arch=amd64,peer_id=abc,connection_pair_id=pair total_seconds=100000 1234567890`
err := validateLine(line)
require.Error(t, err)
assert.Contains(t, err.Error(), "too large")

View File

@@ -56,6 +56,9 @@ type metricsImplementation interface {
// RecordSyncDuration records how long it took to process a sync message
RecordSyncDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration)
// RecordSyncPhase records how long a single sub-phase of sync processing took
RecordSyncPhase(ctx context.Context, agentInfo AgentInfo, phase string, duration time.Duration)
// RecordLoginDuration records how long the login to management took
RecordLoginDuration(ctx context.Context, agentInfo AgentInfo, duration time.Duration, success bool)
@@ -127,6 +130,18 @@ func (c *ClientMetrics) RecordSyncDuration(ctx context.Context, duration time.Du
c.impl.RecordSyncDuration(ctx, agentInfo, duration)
}
// RecordSyncPhase records the duration of a single sub-phase of sync processing
func (c *ClientMetrics) RecordSyncPhase(ctx context.Context, phase string, duration time.Duration) {
if c == nil {
return
}
c.mu.RLock()
agentInfo := c.agentInfo
c.mu.RUnlock()
c.impl.RecordSyncPhase(ctx, agentInfo, phase, duration)
}
// RecordLoginDuration records how long the login to management server took
func (c *ClientMetrics) RecordLoginDuration(ctx context.Context, duration time.Duration, success bool) {
if c == nil {

View File

@@ -70,6 +70,9 @@ func (m *mockMetrics) RecordConnectionStages(_ context.Context, _ AgentInfo, _ s
func (m *mockMetrics) RecordSyncDuration(_ context.Context, _ AgentInfo, _ time.Duration) {
}
func (m *mockMetrics) RecordSyncPhase(_ context.Context, _ AgentInfo, _ string, _ time.Duration) {
}
func (m *mockMetrics) RecordLoginDuration(_ context.Context, _ AgentInfo, _ time.Duration, _ bool) {
}

View File

@@ -195,14 +195,14 @@ func (h *Handshaker) sendOffer() error {
}
offer := h.buildOfferAnswer()
h.log.Infof("sending offer with serial: %s", offer.SessionIDString())
h.log.Debugf("sending offer with serial: %s", offer.SessionIDString())
return h.signaler.SignalOffer(offer, h.config.Key)
}
func (h *Handshaker) sendAnswer() error {
answer := h.buildOfferAnswer()
h.log.Infof("sending answer with serial: %s", answer.SessionIDString())
h.log.Debugf("sending answer with serial: %s", answer.SessionIDString())
return h.signaler.SignalAnswer(answer, h.config.Key)
}

View File

@@ -192,6 +192,7 @@ func (s *StatusChangeSubscription) Events() chan map[string]RouterState {
// Pure read methods take RLock; anything that mutates state takes Lock.
type Status struct {
mux sync.RWMutex
muxRelays sync.RWMutex
peers map[string]State
ipToKey map[string]string
changeNotify map[string]map[string]*StatusChangeSubscription // map[peerID]map[subscriptionID]*StatusChangeSubscription
@@ -244,8 +245,8 @@ func NewRecorder(mgmAddress string) *Status {
}
func (d *Status) SetRelayMgr(manager *relayClient.Manager) {
d.mux.Lock()
defer d.mux.Unlock()
d.muxRelays.Lock()
defer d.muxRelays.Unlock()
d.relayMgr = manager
}
@@ -906,8 +907,8 @@ func (d *Status) MarkSignalConnected() {
}
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
d.mux.Lock()
defer d.mux.Unlock()
d.muxRelays.Lock()
defer d.muxRelays.Unlock()
d.relayStates = relayResults
}
@@ -1018,24 +1019,26 @@ func (d *Status) GetSignalState() SignalState {
// GetRelayStates returns the stun/turn/permanent relay states
func (d *Status) GetRelayStates() []relay.ProbeResult {
d.mux.RLock()
defer d.mux.RUnlock()
d.muxRelays.RLock()
if d.relayMgr == nil {
return d.relayStates
defer d.muxRelays.RUnlock()
return slices.Clone(d.relayStates)
}
relayMgr := d.relayMgr
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates)
d.muxRelays.RUnlock()
states := d.relayMgr.RelayStates()
states := relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
if connErr := relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
for _, r := range d.relayMgr.ServerURLs() {
for _, r := range relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
Err: err,

View File

@@ -433,7 +433,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed {
if input.ServerSSHAllowed != nil && (config.ServerSSHAllowed == nil || *input.ServerSSHAllowed != *config.ServerSSHAllowed) {
if *input.ServerSSHAllowed {
log.Infof("enabling SSH server")
} else {

View File

@@ -242,6 +242,35 @@ func TestWireguardPortDefaultVsExplicit(t *testing.T) {
}
}
func TestUpdateConfigServerSSHAllowedNotSet(t *testing.T) {
// Configs written before ServerSSHAllowed was introduced lack the field and
// unmarshal to nil. Supplying the SSH server flag on top of such a config must
// apply the value instead of panicking on a nil pointer dereference.
tests := []struct {
name string
input *bool
want bool
}{
{"enable", util.True(), true},
{"disable", util.False(), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
configPath := filepath.Join(t.TempDir(), "config.json")
require.NoError(t, os.WriteFile(configPath, []byte("{}"), 0600))
config, err := UpdateConfig(ConfigInput{
ConfigPath: configPath,
ServerSSHAllowed: tt.input,
})
require.NoError(t, err)
require.NotNil(t, config.ServerSSHAllowed, "ServerSSHAllowed should be set from input")
assert.Equal(t, tt.want, *config.ServerSSHAllowed)
})
}
}
func TestUpdateOldManagementURL(t *testing.T) {
origProber := newMgmProber
newMgmProber = func(_ context.Context, _ string, _ wgtypes.Key, _ bool) (mgmProber, error) {

View File

@@ -226,12 +226,11 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
// pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return
}
// All query types for an intercepted domain are forwarded to the peer's
// DNS forwarder, which owns the name. Falling through to the system
// resolver would let it answer NXDOMAIN for a name it isn't authoritative
// for, poisoning the whole name (including the A/AAAA records the route
// does serve). The forwarder answers NODATA for types it cannot resolve.
d.mu.RLock()
peerKey := d.currentPeerKey
d.mu.RUnlock()
@@ -293,19 +292,6 @@ func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger
}
}
// continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil {
logger.Errorf("failed writing DNS continue response: %v", err)
}
}
func (d *DnsInterceptor) getUpstreamIP(peerKey string) (netip.Addr, error) {
peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey)
if !exists {

View File

@@ -1,3 +1,5 @@
//go:build privileged
package routemanager
import (

View File

@@ -0,0 +1,69 @@
//go:build linux && !android
package systemops
import (
"fmt"
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}

View File

@@ -0,0 +1,191 @@
//go:build (darwin || dragonfly || freebsd || netbsd || openbsd) && privileged
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -3,79 +3,24 @@
package systemops
import (
"fmt"
"net"
"net/netip"
"os/exec"
"regexp"
"runtime"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/route"
)
// Interface names used by the shared routing test fixtures. Kept untagged (no
// privileged build tag) so the non-privileged test files in this package compile.
//
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedVPNint = "utun100"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedExternalInt = "lo0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedInternalInt = "lo0"
func init() {
testCases = append(testCases, []testCase{
{
name: "To more specific route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
},
}...)
}
func TestConcurrentRoutes(t *testing.T) {
baseIP := netip.MustParseAddr("192.0.2.0")
var intf *net.Interface
var nexthop Nexthop
_, intf = setupDummyInterface(t)
nexthop = Nexthop{netip.Addr{}, intf}
r := New(nil, nil)
var wg sync.WaitGroup
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.addToRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to add route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
baseIP = netip.MustParseAddr("192.0.2.0")
for i := 0; i < 1024; i++ {
wg.Add(1)
go func(ip netip.Addr) {
defer wg.Done()
prefix := netip.PrefixFrom(ip, 32)
if err := r.removeFromRouteTable(prefix, nexthop); err != nil {
t.Errorf("Failed to remove route for %s: %v", prefix, err)
}
}(baseIP)
baseIP = baseIP.Next()
}
wg.Wait()
}
func TestBits(t *testing.T) {
tests := []struct {
name string
@@ -122,122 +67,3 @@ func TestBits(t *testing.T) {
})
}
}
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
t.Helper()
if runtime.GOOS == "darwin" {
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
require.NoError(t, err, "Failed to create loopback alias")
t.Cleanup(func() {
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
assert.NoError(t, err, "Failed to remove loopback alias")
})
return intf
}
prefix, err := netip.ParsePrefix(ipAddressCIDR)
require.NoError(t, err, "Failed to parse prefix")
netIntf, err := net.InterfaceByName(intf)
require.NoError(t, err, "Failed to get interface by name")
nexthop := Nexthop{netip.Addr{}, netIntf}
r := New(nil, nil)
err = r.addToRouteTable(prefix, nexthop)
require.NoError(t, err, "Failed to add route to table")
t.Cleanup(func() {
err := r.removeFromRouteTable(prefix, nexthop)
assert.NoError(t, err, "Failed to remove route from table")
})
return intf
}
func addDummyRoute(t *testing.T, dstCIDR string, gw netip.Addr, _ string) {
t.Helper()
var originalNexthop net.IP
if dstCIDR == "0.0.0.0/0" {
var err error
originalNexthop, err = fetchOriginalGateway()
if err != nil {
t.Logf("Failed to fetch original gateway: %v", err)
}
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
t.Logf("Failed to delete route: %v, output: %s", err, output)
}
}
t.Cleanup(func() {
if originalNexthop != nil {
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
assert.NoError(t, err, "Failed to restore original route")
}
})
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
require.NoError(t, err, "Failed to add route")
t.Cleanup(func() {
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
assert.NoError(t, err, "Failed to remove route")
})
}
func fetchOriginalGateway() (net.IP, error) {
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
if err != nil {
return nil, err
}
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
if len(matches) == 0 {
return nil, fmt.Errorf("gateway not found")
}
return net.ParseIP(matches[1]), nil
}
// setupDummyInterface creates a dummy tun interface for FreeBSD route testing
func setupDummyInterface(t *testing.T) (netip.Addr, *net.Interface) {
t.Helper()
if runtime.GOOS == "darwin" {
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), &net.Interface{Name: "lo0"}
}
output, err := exec.Command("ifconfig", "tun", "create").CombinedOutput()
require.NoError(t, err, "Failed to create tun interface: %s", string(output))
tunName := strings.TrimSpace(string(output))
output, err = exec.Command("ifconfig", tunName, "192.168.1.1", "netmask", "255.255.0.0", "192.168.1.2", "up").CombinedOutput()
require.NoError(t, err, "Failed to configure tun interface: %s", string(output))
intf, err := net.InterfaceByName(tunName)
require.NoError(t, err, "Failed to get interface by name")
t.Cleanup(func() {
if err := exec.Command("ifconfig", tunName, "destroy").Run(); err != nil {
t.Logf("Failed to destroy tun interface %s: %v", tunName, err)
}
})
return netip.AddrFrom4([4]byte{192, 168, 1, 2}), intf
}
func setupDummyInterfacesAndRoutes(t *testing.T) {
t.Helper()
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
addDummyRoute(t, "0.0.0.0/0", netip.AddrFrom4([4]byte{192, 168, 0, 1}), defaultDummy)
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
addDummyRoute(t, "10.0.0.0/8", netip.AddrFrom4([4]byte{192, 168, 1, 1}), otherDummy)
}

View File

@@ -0,0 +1,17 @@
//go:build !android && !ios
package systemops
import (
"context"
"net"
)
// dialer is shared by the per-platform routing test cases. Kept untagged (no
// privileged build tag) so the non-privileged test files compile on every platform.
//
//nolint:unused // consumed by the privileged-tagged routing tests
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View File

@@ -1,4 +1,4 @@
//go:build !android && !ios
//go:build !android && !ios && privileged
package systemops
@@ -26,11 +26,6 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type dialer interface {
Dial(network, address string) (net.Conn, error)
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func TestAddVPNRoute(t *testing.T) {
testCases := []struct {
name string
@@ -515,125 +510,3 @@ func setupTestEnv(t *testing.T) {
// unique route in vpn table
setupRouteAndCleanup(t, r, netip.MustParsePrefix("172.16.0.0/12"), intf)
}
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@@ -0,0 +1,132 @@
//go:build !android && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsVpnRoute(t *testing.T) {
tests := []struct {
name string
addr string
vpnRoutes []string
localRoutes []string
expectedVpn bool
expectedPrefix netip.Prefix
}{
{
name: "Match in VPN routes",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Match in local routes",
addr: "10.1.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("10.0.0.0/8"),
},
{
name: "No match",
addr: "172.16.0.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Default route ignored",
addr: "192.168.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Default route matches but ignored",
addr: "172.16.1.1",
vpnRoutes: []string{"0.0.0.0/0", "192.168.1.0/24"},
localRoutes: []string{"10.0.0.0/8"},
expectedVpn: false,
expectedPrefix: netip.Prefix{},
},
{
name: "Longest prefix match local",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.0.0/16"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match local multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26", "192.168.0.0/28"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/28"),
},
{
name: "Longest prefix match vpn",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.0.0/16"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
{
name: "Longest prefix match vpn multiple",
addr: "192.168.0.1",
vpnRoutes: []string{"192.168.0.0/16", "192.168.0.0/25", "192.168.0.0/27"},
localRoutes: []string{"192.168.0.0/24", "192.168.0.0/26"},
expectedVpn: true,
expectedPrefix: netip.MustParsePrefix("192.168.0.0/27"),
},
{
name: "Duplicate prefix in both",
addr: "192.168.1.1",
vpnRoutes: []string{"192.168.1.0/24"},
localRoutes: []string{"192.168.1.0/24"},
expectedVpn: false,
expectedPrefix: netip.MustParsePrefix("192.168.1.0/24"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr, err := netip.ParseAddr(tt.addr)
if err != nil {
t.Fatalf("Failed to parse address %s: %v", tt.addr, err)
}
var vpnRoutes, localRoutes []netip.Prefix
for _, route := range tt.vpnRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse VPN route %s: %v", route, err)
}
vpnRoutes = append(vpnRoutes, prefix)
}
for _, route := range tt.localRoutes {
prefix, err := netip.ParsePrefix(route)
if err != nil {
t.Fatalf("Failed to parse local route %s: %v", route, err)
}
localRoutes = append(localRoutes, prefix)
}
isVpn, matchedPrefix := isVpnRoute(addr, vpnRoutes, localRoutes)
assert.Equal(t, tt.expectedVpn, isVpn, "isVpnRoute should return expectedVpn value")
assert.Equal(t, tt.expectedPrefix, matchedPrefix, "isVpnRoute should return expectedVpn prefix")
})
}
}

View File

@@ -1,13 +1,10 @@
//go:build !android
//go:build linux && !android && privileged
package systemops
import (
"errors"
"fmt"
"net"
"os"
"strings"
"syscall"
"testing"
@@ -18,10 +15,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
)
var expectedVPNint = "wgtest0"
var expectedExternalInt = "dummyext0"
var expectedInternalInt = "dummyint0"
func init() {
testCases = append(testCases, []testCase{
{
@@ -33,62 +26,6 @@ func init() {
}...)
}
func TestEntryExists(t *testing.T) {
tempDir := t.TempDir()
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
content := []string{
"1000 reserved",
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
"9999 other_table",
}
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
file, err := os.Open(tempFilePath)
require.NoError(t, err)
defer func() {
assert.NoError(t, file.Close())
}()
tests := []struct {
name string
id int
shouldExist bool
err error
}{
{
name: "ExistsWithNetbirdPrefix",
id: 7120,
shouldExist: true,
err: nil,
},
{
name: "ExistsWithDifferentName",
id: 1000,
shouldExist: true,
err: ErrTableIDExists,
},
{
name: "DoesNotExist",
id: 1234,
shouldExist: false,
err: nil,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
exists, err := entryExists(file, tc.id)
if tc.err != nil {
assert.ErrorIs(t, err, tc.err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tc.shouldExist, exists)
})
}
}
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
t.Helper()

View File

@@ -0,0 +1,15 @@
//go:build linux && !android
package systemops
// Interface names used by the shared routing test fixtures. Kept untagged (no
// privileged build tag) so the non-privileged test files in this package compile.
//
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedVPNint = "wgtest0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedExternalInt = "dummyext0"
//nolint:unused // consumed by the privileged-tagged routing tests
var expectedInternalInt = "dummyint0"

View File

@@ -0,0 +1,83 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
package systemops
import (
"net"
nbnet "github.com/netbirdio/netbird/client/net"
)
// Shared, non-privileged routing test fixtures. The privileged TestRouting (and its
// per-platform init() appenders) consume these; they live here so the unprivileged
// BSD/darwin test files compile without the privileged build tag.
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
//nolint:unused // consumed by the privileged-tagged routing tests
type testCase struct {
name string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
//nolint:unused // consumed by the privileged-tagged routing tests
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
//nolint:unused // consumed by the privileged-tagged routing tests
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}

View File

@@ -1,4 +1,4 @@
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
//go:build ((linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly) && privileged
package systemops
@@ -20,63 +20,6 @@ import (
nbnet "github.com/netbirdio/netbird/client/net"
)
type PacketExpectation struct {
SrcIP net.IP
DstIP net.IP
SrcPort int
DstPort int
UDP bool
TCP bool
}
type testCase struct {
name string
expectedInterface string
dialer dialer
expectedPacket PacketExpectation
}
var testCases = []testCase{
{
name: "To external host without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To external host with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
},
{
name: "To duplicate internal route with custom dialer via physical interface",
expectedInterface: expectedInternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
expectedInterface: expectedInternalInt,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
},
{
name: "To unique vpn route with custom dialer via physical interface",
expectedInterface: expectedExternalInt,
dialer: nbnet.NewDialer(),
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
},
{
name: "To unique vpn route without custom dialer via vpn",
expectedInterface: expectedVPNint,
dialer: &net.Dialer{},
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
},
}
func TestRouting(t *testing.T) {
nbnet.Init()
for _, tc := range testCases {
@@ -102,16 +45,6 @@ func TestRouting(t *testing.T) {
}
}
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
return PacketExpectation{
SrcIP: net.ParseIP(srcIP),
DstIP: net.ParseIP(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
UDP: true,
}
}
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
t.Helper()

View File

@@ -1,3 +1,5 @@
//go:build windows && privileged
package systemops
import (

View File

@@ -11,6 +11,8 @@ import (
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
//
//nolint:unused // consumed by the privileged-tagged routing tests
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()

View File

@@ -1,4 +1,4 @@
//go:build linux && !android
//go:build linux && !android && privileged
package systemops

View File

@@ -8,11 +8,14 @@ import (
"testing"
)
//nolint:unused // consumed by the privileged-tagged routing tests
const loopbackIfaceWindows = "Loopback Pseudo-Interface 1"
// ensureIPv6DefaultRoute installs an IPv6 default route via the loopback
// interface so route lookups for global IPv6 prefixes resolve in environments
// without v6 connectivity. If a default already exists it is left alone.
//
//nolint:unused // consumed by the privileged-tagged routing tests
func ensureIPv6DefaultRoute(t *testing.T) {
t.Helper()

View File

@@ -67,6 +67,7 @@ var boolStringLiterals = map[string]bool{
"no": false,
}
// Policy holds MDM-managed settings read from the platform source. A nil or
// empty Policy means no enforcement is active.
type Policy struct {

View File

@@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) {
func TestPolicy_HasKey(t *testing.T) {
p := NewPolicy(map[string]any{
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true,
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true,
})
assert.False(t, p.IsEmpty())
assert.True(t, p.HasKey(KeyManagementURL))
@@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) {
func TestPolicy_GetString(t *testing.T) {
p := NewPolicy(map[string]any{
KeyManagementURL: "https://corp.example.com",
KeyDisableProfiles: true, // wrong type for GetString
KeyPreSharedKey: "", // empty rejected
KeyDisableProfiles: true, // wrong type for GetString
KeyPreSharedKey: "", // empty rejected
})
v, ok := p.GetString(KeyManagementURL)
assert.True(t, ok)

View File

@@ -0,0 +1,235 @@
//go:build privileged
package server
import (
"context"
"net"
"os/user"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -2,124 +2,22 @@ package server
import (
"context"
"net"
"net/url"
"os/user"
"path/filepath"
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator/validator"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/job"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
daemonProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
var (
kaep = keepalive.EnforcementPolicy{
MinTime: 15 * time.Second,
PermitWithoutStream: true,
}
kasp = keepalive.ServerParameters{
MaxConnectionIdle: 15 * time.Second,
MaxConnectionAgeGrace: 5 * time.Second,
Time: 5 * time.Second,
Timeout: 2 * time.Second,
}
)
// TestConnectWithRetryRuns checks that the connectWithRetry function runs and runs the retries according to the times specified via environment variables
// we will use a management server started via to simulate the server and capture the number of retries
func TestConnectWithRetryRuns(t *testing.T) {
// start the signal server
_, signalAddr, err := startSignal(t)
if err != nil {
t.Fatalf("failed to start signal server: %v", err)
}
counter := 0
// start the management server
_, mgmtAddr, err := startManagement(t, signalAddr, &counter)
if err != nil {
t.Fatalf("failed to start management server: %v", err)
}
ctx := internal.CtxInitState(context.Background())
ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
defer cancel()
// create new server
ic := profilemanager.ConfigInput{
ManagementURL: "http://" + mgmtAddr,
ConfigPath: t.TempDir() + "/test-profile.json",
}
config, err := profilemanager.UpdateOrCreateConfig(ic)
if err != nil {
t.Fatalf("failed to create config: %v", err)
}
currUser, err := user.Current()
require.NoError(t, err)
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
t.Fatalf("failed to set active profile state: %v", err)
}
s := New(ctx, "debug", "", false, false, false, false)
s.config = config
s.statusRecorder = peer.NewRecorder(config.ManagementURL.String())
t.Setenv(retryInitialIntervalVar, "1s")
t.Setenv(maxRetryIntervalVar, "2s")
t.Setenv(maxRetryTimeVar, "5s")
t.Setenv(retryMultiplierVar, "1")
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
if counter < 3 {
t.Fatalf("expected counter > 2, got %d", counter)
}
}
func TestServer_Up(t *testing.T) {
tempDir := t.TempDir()
origDefaultProfileDir := profilemanager.DefaultConfigPathDir
@@ -259,119 +157,3 @@ func TestServer_SubcribeEvents(t *testing.T) {
assert.NoError(t, err)
}
type mockServer struct {
mgmtProto.ManagementServiceServer
counter *int
}
func (m *mockServer) Login(ctx context.Context, req *mgmtProto.EncryptedMessage) (*mgmtProto.EncryptedMessage, error) {
*m.counter++
return m.ManagementServiceServer.Login(ctx, req)
}
func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Server, string, error) {
t.Helper()
dataDir := t.TempDir()
config := &config.Config{
Stuns: []*config.Host{},
TURNConfig: &config.TURNConfig{},
Signal: &config.Host{
Proto: "http",
URI: signalAddr,
},
Datadir: dataDir,
HttpConfig: nil,
}
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, "", err
}
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir)
if err != nil {
return nil, "", err
}
t.Cleanup(cleanUp)
eventStore := &activity.InMemoryEventStore{}
if err != nil {
return nil, "", err
}
ctrl := gomock.NewController(t)
t.Cleanup(ctrl.Finish)
permissionsManagerMock := permissions.NewMockManager(ctrl)
peersManager := peers.NewManager(store, permissionsManagerMock)
settingsManagerMock := settings.NewMockManager(ctrl)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := validator.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
groupsManager := groups.NewManagerMock()
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
peersUpdateManager := update_channel.NewPeersUpdateManager(metrics)
networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
if err != nil {
return nil, "", err
}
secretsManager, err := nbgrpc.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager, groupsManager)
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
if err != nil {
return nil, "", err
}
mock := &mockServer{
ManagementServiceServer: mgmtServer,
counter: counter,
}
mgmtProto.RegisterManagementServiceServer(s, mock)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}
func startSignal(t *testing.T) (*grpc.Server, string, error) {
t.Helper()
s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp))
lis, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
srv, err := signalServer.NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
proto.RegisterSignalExchangeServer(s, srv)
go func() {
if err = s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err)
}
}()
return s, lis.Addr().String(), nil
}

View File

@@ -0,0 +1,118 @@
//go:build privileged
package client
import (
"context"
"errors"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}

View File

@@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
@@ -78,53 +77,6 @@ func TestSSHClient_DialWithKey(t *testing.T) {
assert.NotNil(t, client.client)
}
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ConnectionHandling(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
@@ -154,59 +106,6 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
}
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "timeout"),
"Expected timeout-related error, got: %v", err)
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}
func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)

View File

@@ -0,0 +1,423 @@
//go:build privileged
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -1,25 +1,12 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
@@ -28,11 +15,7 @@ import (
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
)
func TestMain(m *testing.M) {
@@ -106,331 +89,6 @@ func TestSSHProxy_verifyHostKey(t *testing.T) {
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
// Configure SSH authorization for the test user
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0}, // Index 0 in AuthorizedUsers
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
// when forwarding commands to the backend. This is critical for tools like
// Ansible that send commands such as:
//
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
//
// The single quotes must be preserved so the backend shell receives the
// subshell expression as a single argument to -c.
func TestSSHProxy_CommandQuoting(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
sshClient, cleanup := setupProxySSHClient(t)
defer cleanup()
// These commands simulate what the SSH protocol delivers as exec payloads.
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
// the local shell strips the outer single quotes, and the SSH exec request
// contains the raw string: /bin/sh -c "( echo hello )"
//
// The proxy must forward this string verbatim. Using session.Command()
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
// the command on the backend.
tests := []struct {
name string
command string
expect string
}{
{
name: "subshell_in_double_quotes",
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
expect: "from-subshell\nouter\n",
},
{
name: "printf_with_special_chars",
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
expect: "hello world\n",
},
{
name: "nested_command_substitution",
command: `/bin/sh -c "echo $(echo nested)"`,
expect: "nested\n",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
session, err := sshClient.NewSession()
require.NoError(t, err)
defer func() { _ = session.Close() }()
var stderrBuf bytes.Buffer
session.Stderr = &stderrBuf
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output(tc.command)
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
if stderrBuf.Len() > 0 {
t.Logf("stderr: %s", stderrBuf.String())
}
require.NoError(t, err, "command should succeed: %s", tc.command)
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
case <-time.After(5 * time.Second):
t.Fatalf("command timed out: %s", tc.command)
}
})
}
}
// setupProxySSHClient creates a full proxy test environment and returns
// an SSH client connected through the proxy to a backend NetBird SSH server.
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
t.Helper()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audiences: []string{audience},
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
testUsername := testutil.GetTestUsername(t)
testJWTUser := "test-username"
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
require.NoError(t, err)
authConfig := &sshauth.Config{
UserIDClaim: sshauth.DefaultUserIDClaim,
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
MachineUsers: map[string][]uint32{
testUsername: {0},
},
}
sshServer.UpdateSSHAuth(authConfig)
sshServerAddr := server.StartTestServer(t, sshServer)
mockDaemon := startMockDaemon(t)
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
require.NoError(t, err)
origStdin := os.Stdin
origStdout := os.Stdout
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
clientConn, proxyConn := net.Pipe()
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
go func() {
_ = proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 5 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err)
client := cryptossh.NewClient(sshClientConn, chans, reqs)
cleanupFn := func() {
_ = client.Close()
_ = clientConn.Close()
cancel()
os.Stdin = origStdin
os.Stdout = origStdout
_ = sshServer.Stop()
mockDaemon.stop()
jwksServer.Close()
}
return client, cleanupFn
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
@@ -492,10 +150,6 @@ func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
@@ -508,63 +162,3 @@ func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": user,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -0,0 +1,66 @@
//go:build unix && privileged
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}

View File

@@ -73,61 +73,6 @@ func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
}
}
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
// This test requires root privileges and will be skipped if not running as root
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {

View File

@@ -3,6 +3,7 @@ package system
import (
"context"
"net/netip"
"slices"
"strings"
log "github.com/sirupsen/logrus"
@@ -121,6 +122,23 @@ func (i *Info) SetFlags(
}
}
// removeAddresses drops network addresses whose IP matches any of the given
// addresses, regardless of prefix length. Used to exclude the NetBird overlay
// address, which otherwise churns the meta as the interface comes and goes.
func (i *Info) removeAddresses(ips ...netip.Addr) {
if len(ips) == 0 {
return
}
filtered := i.NetworkAddresses[:0]
for _, addr := range i.NetworkAddresses {
if slices.Contains(ips, addr.NetIP.Addr()) {
continue
}
filtered = append(filtered, addr)
}
i.NetworkAddresses = filtered
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
func extractUserAgent(ctx context.Context) string {
md, hasMeta := metadata.FromOutgoingContext(ctx)
@@ -147,7 +165,9 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
// excludeIPs are dropped from the reported network addresses (e.g. our own
// WireGuard overlay address, which otherwise churns the peer meta).
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks, excludeIPs ...netip.Addr) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))
processCheckPaths := make([]string, 0)
for _, check := range checks {
@@ -162,6 +182,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
info := GetInfo(ctx)
info.Files = files
info.removeAddresses(excludeIPs...)
log.Debugf("all system information gathered successfully")
return info, nil

View File

@@ -2,6 +2,7 @@ package system
import (
"context"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
@@ -43,3 +44,42 @@ func Test_NetAddresses(t *testing.T) {
t.Errorf("no network addresses found")
}
}
func TestInfo_RemoveAddresses(t *testing.T) {
addr := func(cidr string) NetworkAddress {
return NetworkAddress{NetIP: netip.MustParsePrefix(cidr)}
}
info := &Info{
NetworkAddresses: []NetworkAddress{
addr("192.168.1.7/24"),
addr("100.76.70.97/32"), // overlay v4 (host mask /32)
addr("2001:818:c51b:4800:845:a65d:ae6f:623f/64"), // real global v6
addr("fd00:1234::1/64"), // overlay v6
},
}
// Overlay addresses as the engine knows them, with a different mask (/16, /64).
info.removeAddresses(
netip.MustParseAddr("100.76.70.97"),
netip.MustParseAddr("fd00:1234::1"),
)
want := []string{"192.168.1.7/24", "2001:818:c51b:4800:845:a65d:ae6f:623f/64"}
if len(info.NetworkAddresses) != len(want) {
t.Fatalf("got %d addresses, want %d: %v", len(info.NetworkAddresses), len(want), info.NetworkAddresses)
}
for i, w := range want {
if got := info.NetworkAddresses[i].NetIP.String(); got != w {
t.Errorf("address[%d] = %s, want %s", i, got, w)
}
}
}
func TestInfo_RemoveAddresses_NoOp(t *testing.T) {
info := &Info{NetworkAddresses: []NetworkAddress{{NetIP: netip.MustParsePrefix("10.0.0.1/24")}}}
info.removeAddresses()
if len(info.NetworkAddresses) != 1 {
t.Errorf("expected no change with empty input, got %v", info.NetworkAddresses)
}
}

View File

@@ -46,7 +46,9 @@ func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
// Skip link-local and multicast: they carry no routable peer info and the
// IPv6 link-local of a flapping NIC churns the meta on every up/down.
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())

View File

@@ -0,0 +1,45 @@
//go:build !ios
package system
import (
"net"
"testing"
)
func mustIPNet(t *testing.T, cidr string) *net.IPNet {
t.Helper()
ip, ipNet, err := net.ParseCIDR(cidr)
if err != nil {
t.Fatalf("parse %q: %v", cidr, err)
}
ipNet.IP = ip
return ipNet
}
func TestToNetworkAddress_Filtering(t *testing.T) {
const mac = "c8:4b:d6:b6:04:ac"
tests := []struct {
name string
cidr string
want bool
}{
{"ipv4 global", "10.65.16.181/23", true},
{"ipv6 global", "2620:52:0:4110:102d:6a98:ee75:8b92/64", true},
{"ipv4 loopback", "127.0.0.1/8", false},
{"ipv6 loopback", "::1/128", false},
{"ipv6 link-local", "fe80::871:4c25:23d7:2529/64", false},
{"ipv4 link-local", "169.254.1.2/16", false},
{"ipv6 multicast", "ff02::1/128", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, got := toNetworkAddress(mustIPNet(t, tt.cidr), mac)
if got != tt.want {
t.Errorf("toNetworkAddress(%s) ok = %v, want %v", tt.cidr, got, tt.want)
}
})
}
}

View File

@@ -0,0 +1,196 @@
//go:build privileged && (linux || darwin)
// Package privileged provides a self-hosting harness that runs the repo's
// privileged-tagged test suite inside a --privileged --cap-add=NET_ADMIN
// container, so developers can exercise the root/system-mutating tests on a
// non-root host with a single `go test` invocation.
package privileged
import (
"bytes"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"time"
"github.com/moby/moby/api/types/container"
"github.com/ory/dockertest/v4"
)
// containerImage / containerTag match the image used by the CI privileged job
// (.github/workflows/golang-test-linux.yml, test_client_on_docker).
const (
containerImage = "golang"
containerTag = "1.25-alpine"
)
const (
containerWorkdir = "/app"
containerGoCache = "/root/.cache/go-build"
containerGoModCache = "/go/pkg/mod"
)
// alpinePackages are the build/runtime deps the privileged tests need, mirroring
// the CI container setup.
const alpinePackages = "ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base"
// privilegedTestPackages is the package list the suite runs, excluding the
// server-side trees and UI/upload helpers, matching the CI Docker job's filter.
const privilegedTestPackages = `go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server`
// testWriter forwards container output to the test log line by line.
type testWriter struct{ t *testing.T }
func (w testWriter) Write(p []byte) (int, error) {
for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
w.t.Log(line)
}
return len(p), nil
}
// TestRunPrivilegedSuiteInDocker spins up a privileged container, mounts the repo,
// and runs `go test -tags 'devcert privileged'` inside it. When already running
// inside that container (DOCKER_CI=true) it returns immediately so the real
// privileged tests in the suite execute in place instead of recursing.
func TestRunPrivilegedSuiteInDocker(t *testing.T) {
if os.Getenv("DOCKER_CI") == "true" {
t.Skip("inside privileged container, skipping container spawn; privileged tests run in place")
}
repoRoot, err := findRepoRoot()
if err != nil {
t.Fatalf("locate repo root: %v", err)
}
goCache, goModCache := hostGoCaches(t)
// dockertest reads DOCKER_HOST; point it at the active context's socket when
// the default one is absent (macOS Docker Desktop, Colima, OrbStack).
if host := dockerHost(); host != "" {
t.Setenv("DOCKER_HOST", host)
}
// NewPoolT registers container cleanup via t.Cleanup automatically.
pool := dockertest.NewPoolT(t, "", dockertest.WithMaxWait(30*time.Minute))
// Keep the container alive so the suite runs via Exec, which yields a clean
// exit code (the v4 Resource API exposes no container wait/exit-code).
resource := pool.RunT(t, containerImage,
dockertest.WithTag(containerTag),
dockertest.WithWorkingDir(containerWorkdir),
dockertest.WithMounts([]string{
repoRoot + ":" + containerWorkdir,
goCache + ":" + containerGoCache,
goModCache + ":" + containerGoModCache,
}),
dockertest.WithEnv([]string{
"CGO_ENABLED=1",
"CI=true",
"DOCKER_CI=true",
"CONTAINER=true",
"GOCACHE=" + containerGoCache,
"GOMODCACHE=" + containerGoModCache,
}),
dockertest.WithCmd([]string{"sleep", "infinity"}),
dockertest.WithHostConfig(func(hc *container.HostConfig) {
hc.Privileged = true
hc.CapAdd = []string{"NET_ADMIN"}
}),
dockertest.WithoutReuse(),
)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
result, err := resource.Exec(ctx, []string{"sh", "-c", buildTestScript()})
if err != nil {
t.Fatalf("run privileged suite in container: %v", err)
}
w := testWriter{t}
_, _ = w.Write([]byte(result.StdOut))
_, _ = w.Write([]byte(result.StdErr))
if result.ExitCode != 0 {
t.Fatalf("privileged test suite failed in container (exit code %d)", result.ExitCode)
}
}
// findRepoRoot walks up from the test's working directory to the module root.
func findRepoRoot() (string, error) {
dir, err := os.Getwd()
if err != nil {
return "", err
}
for {
if _, statErr := os.Stat(filepath.Join(dir, "go.mod")); statErr == nil {
return dir, nil
}
parent := filepath.Dir(dir)
if parent == dir {
return "", fmt.Errorf("go.mod not found above %s", dir)
}
dir = parent
}
}
// dockerHost returns a DOCKER_HOST override when the default socket is missing.
// An empty result means the caller should leave DOCKER_HOST untouched (it is
// already set, or the default unix socket exists). When neither is present
// (common on macOS Docker Desktop, Colima and OrbStack, which use a per-user
// socket), it resolves the active docker context's endpoint.
func dockerHost() string {
if os.Getenv("DOCKER_HOST") != "" {
return ""
}
if _, err := os.Stat("/var/run/docker.sock"); err == nil {
return ""
}
out, err := exec.Command("docker", "context", "inspect", "-f", "{{.Endpoints.docker.Host}}").Output()
if err != nil {
return ""
}
return strings.TrimSpace(string(out))
}
// hostGoCaches resolves the host GOCACHE/GOMODCACHE so the container reuses the
// existing build/module cache for speed.
func hostGoCaches(t *testing.T) (string, string) {
t.Helper()
return goEnv(t, "GOCACHE"), goEnv(t, "GOMODCACHE")
}
func goEnv(t *testing.T, key string) string {
t.Helper()
var out bytes.Buffer
cmd := exec.Command("go", "env", key)
cmd.Stdout = &out
if err := cmd.Run(); err != nil {
t.Fatalf("go env %s: %v", key, err)
}
return strings.TrimSpace(out.String())
}
// buildTestScript builds the in-container command. PRIV_PKGS overrides the package
// list (default: the full filtered set); PRIV_RUN adds a -run test-name filter.
// Both empty reproduces the full privileged suite.
func buildTestScript() string {
pkgs := privilegedTestPackages + " | xargs"
if p := os.Getenv("PRIV_PKGS"); p != "" {
pkgs = "echo " + p + " | xargs"
}
runFilter := ""
if r := os.Getenv("PRIV_RUN"); r != "" {
runFilter = "-run '" + r + "' "
}
return fmt.Sprintf(
"apk update >/dev/null && apk add --no-cache %s >/dev/null && %s go test -buildvcs=false -tags 'devcert privileged' %s-v -timeout 20m -p 1",
alpinePackages, pkgs, runFilter,
)
}

View File

@@ -336,11 +336,11 @@ type serviceClient struct {
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
// AND s.connected — both must be true for the menus to be active.
// Zero value (false) matches the Disable() call at AddMenuItem time.
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
networksMenuEnabled bool
showNetworks bool
wNetworks fyne.Window
wProfiles fyne.Window
wQuickActions fyne.Window
eventManager *event.Manager

View File

@@ -53,9 +53,6 @@ type NameServerGroup struct {
ID string `gorm:"primaryKey"`
// AccountID is a reference to Account that this object belongs
AccountID string `gorm:"index"`
// AccountSeqID is a per-account monotonically increasing identifier used as the
// compact wire id when sending NetworkMap components to capable peers.
AccountSeqID uint32 `json:"-" gorm:"index:idx_nameserver_groups_account_seq_id;not null;default:0"`
// Name group name
Name string
// Description group description

View File

@@ -0,0 +1,78 @@
# Privileged tests
Some tests in this repo need `root` or mutate host network state: they create
TUN/WireGuard interfaces, open netlink/raw sockets, run eBPF programs, or shell
out to `ip`/`iptables`/`nft`/`ifconfig`/`route`. Running them on a developer
machine would require `sudo` and could leave stray interfaces or routes behind.
These tests are gated behind the **`privileged` build tag** so the default test
run is host-safe.
## Running tests
```bash
# Host-safe: excludes privileged tests. Runs as a normal user, no sudo.
make test-unit
# equivalently:
go test -tags devcert ./...
# Privileged suite: runs the privileged-tagged tests inside a
# --privileged --cap-add=NET_ADMIN container (requires Docker).
make test-privileged
# Narrow the container run to a single test / package:
PRIV_RUN=TestNftablesManager PRIV_PKGS=./client/firewall/nftables/... make test-privileged
```
`PRIV_RUN` adds a `-run` test-name filter and `PRIV_PKGS` overrides the package
list; both are optional and default to the full privileged suite.
`make test-privileged` invokes the `ory/dockertest` harness in
`client/testutil/privileged/`. The harness:
1. Skips immediately when it detects it is already inside the container
(`DOCKER_CI=true`), so the privileged tests run in place instead of recursing.
2. Otherwise spins up a `golang:1.25-alpine` container (matching CI),
bind-mounts the repo and the host Go build/module caches, installs the
required packages, and runs `go test -tags 'devcert privileged'` over the
client packages.
3. Streams the container's output to the test log and fails if the suite fails.
## Adding a privileged test
A test is privileged if it does any of:
- creates a real interface via `iface.NewWGIFace(...).Create()`,
- opens a netlink or raw socket that hard-fails without `CAP_NET_ADMIN`,
- runs an eBPF program (`ebpf.*.Listen()`),
- shells out to `ip`, `iptables`, `nft`, `ifconfig`, or `route` to change state.
Add the tag to the **top** of the file, combined with any existing platform
constraint:
```go
//go:build privileged && linux
package foo
```
If a file mixes privileged and pure-logic tests, **split it**: keep the pure
tests (and any shared data — type/var declarations, table-driven `testCases`,
helper interfaces) in an untagged file, and move the privileged tests into a
`*_privileged_test.go` file with the tag. Shared declarations must stay untagged,
otherwise the unprivileged files in the package will not compile.
Always verify both build modes compile on every target platform:
```bash
go vet -tags devcert ./...
go vet -tags 'devcert privileged' ./...
```
## CI
- The `Client / Unit` job runs `go test -tags devcert` with **no** `sudo` — only
host-safe tests.
- The `Client (Docker) / Unit` job runs `go test -tags 'devcert privileged'`
inside a `--privileged --cap-add=NET_ADMIN` container, which is where the
privileged tests actually execute.

13
go.mod
View File

@@ -64,7 +64,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.4.0
github.com/gopacket/gopacket v1.6.1
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
@@ -78,10 +78,12 @@ require (
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.72
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/moby/moby/api v1.54.1
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
github.com/oapi-codegen/runtime v1.1.2
github.com/okta/okta-sdk-golang/v2 v2.18.0
github.com/ory/dockertest/v4 v4.0.0
github.com/oschwald/maxminddb-golang v1.12.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/petermattis/goid v0.0.0-20250303134427-723919f7f203
@@ -145,7 +147,7 @@ require (
dario.cat/mergo v1.0.1 // indirect
filippo.io/edwards25519 v1.1.1 // indirect
github.com/AppsFlyer/go-sundheit v0.6.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
github.com/Azure/go-ntlmssp v0.1.0 // indirect
github.com/BurntSushi/toml v1.5.0 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
@@ -177,6 +179,8 @@ require (
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cenkalti/backoff/v5 v5.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
github.com/containerd/errdefs/pkg v0.3.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -271,11 +275,12 @@ require (
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/moby/client v0.4.0 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/sys/user v0.3.0 // indirect
github.com/moby/sys/userns v0.1.0 // indirect
github.com/moby/term v0.5.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
@@ -341,7 +346,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a
replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

32
go.sum
View File

@@ -23,8 +23,8 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
github.com/AppsFlyer/go-sundheit v0.6.0/go.mod h1:LDdBHD6tQBtmHsdW+i1GwdTt6Wqc0qazf5ZEJVTbTME=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg=
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Azure/go-ntlmssp v0.1.0 h1:DjFo6YtWzNqNvQdrwEyr/e4nhU3vRiwenz5QX7sFz+A=
github.com/Azure/go-ntlmssp v0.1.0/go.mod h1:NYqdhxd/8aAct/s4qSYZEerdPuH1liG2/X9DiVTbhpk=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
@@ -117,6 +117,10 @@ github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI=
github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M=
github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE=
github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
@@ -315,8 +319,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gopacket/gopacket v1.6.1 h1:S19Ok/KVGDFNHVW2uCva5U0vZ+uHqiZQdxteL50v6Ak=
github.com/gopacket/gopacket v1.6.1/go.mod h1:i3NaGaqfoWKAr1+g7qxEdWsmfT+MXuWkAe9+THv8LME=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -480,6 +484,10 @@ github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zx
github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/moby/api v1.54.1 h1:TqVzuJkOLsgLDDwNLmYqACUuTehOHRGKiPhvH8V3Nn4=
github.com/moby/moby/api v1.54.1/go.mod h1:+RQ6wluLwtYaTd1WnPLykIDPekkuyD/ROWQClE83pzs=
github.com/moby/moby/client v0.4.0 h1:S+2XegzHQrrvTCvF6s5HFzcrywWQmuVnhOXe2kiWjIw=
github.com/moby/moby/client v0.4.0/go.mod h1:QWPbvWchQbxBNdaLSpoKpCdf5E+WxFAgNHogCWDoa7g=
github.com/moby/patternmatcher v0.6.0 h1:GmP9lR19aU5GqSSFko+5pRqHi+Ohk1O69aFiKkVGiPk=
github.com/moby/patternmatcher v0.6.0/go.mod h1:hDPoyOpDY7OrrMDLaYoY3hf52gNCR/YOUYxkhApJIxc=
github.com/moby/sys/sequential v0.5.0 h1:OPvI35Lzn9K04PBbCLW0g4LcFAJgHsvXsRyewg5lXtc=
@@ -488,8 +496,8 @@ github.com/moby/sys/user v0.3.0 h1:9ni5DlcW5an3SvRSx4MouotOygvzaXbaSrc/wGDFWPo=
github.com/moby/sys/user v0.3.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs=
github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g=
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
@@ -510,8 +518,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk=
github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a h1:3CWK+yTvRKOcC0Q8VCTGy4l60TEb27CQVS7LkMxwjmw=
github.com/netbirdio/wireguard-go v0.0.0-20260628102922-2834bebf6c1a/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=
@@ -542,6 +550,8 @@ github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/ory/dockertest/v4 v4.0.0 h1:i19aFsO/VXE0VrMk4ifnKW4G/KIJ93PCjLOslxXoPME=
github.com/ory/dockertest/v4 v4.0.0/go.mod h1:b5Ofu8VIxWNhXFvQcLu17pRNQdoUBKtXBW74G4Ygzx8=
github.com/oschwald/maxminddb-golang v1.12.0 h1:9FnTOD0YOhP7DGxGsq4glzpGy5+w7pq50AS6wALUMYs=
github.com/oschwald/maxminddb-golang v1.12.0/go.mod h1:q0Nob5lTCqyQ8WT6FYgS1L7PXKVVbgiymefNwIjPzgY=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
@@ -973,11 +983,13 @@ gorm.io/driver/sqlite v1.5.7/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDa
gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8=
gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
howett.net/plist v1.0.1 h1:37GdZ8tP09Q35o9ych3ehygcsL+HqKSwzctveSlarvM=
howett.net/plist v1.0.1/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g=
pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk=
pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

View File

@@ -308,7 +308,7 @@ func (s *Storage) OpenStorage(logger *slog.Logger) (storage.Storage, error) {
if file == "" {
return nil, fmt.Errorf("sqlite3 storage requires 'file' config")
}
return newSQLite3(file).Open(logger)
return (&sql.SQLite3{File: file}).Open(logger)
case "postgres":
dsn, _ := s.Config["dsn"].(string)
if dsn == "" {

View File

@@ -20,6 +20,7 @@ import (
"github.com/dexidp/dex/server"
"github.com/dexidp/dex/server/signer"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/sql"
jose "github.com/go-jose/go-jose/v4"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
@@ -78,7 +79,7 @@ func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
// Initialize SQLite storage
dbPath := filepath.Join(config.DataDir, "oidc.db")
sqliteConfig := newSQLite3(dbPath)
sqliteConfig := &sql.SQLite3{File: dbPath}
stor, err := sqliteConfig.Open(logger)
if err != nil {
return nil, fmt.Errorf("failed to open storage: %w", err)

View File

@@ -1,15 +0,0 @@
//go:build cgo
package dex
import (
sql "github.com/dexidp/dex/storage/sql"
)
// newSQLite3 builds the dex SQLite3 config. CGO builds use the upstream
// struct that takes a File path. Non-CGO builds get an empty stub whose
// Open() returns the dex "SQLite not available" error — correct behaviour
// for binaries that can't link sqlite3 (e.g. cross-compiled ARM targets).
func newSQLite3(file string) *sql.SQLite3 {
return &sql.SQLite3{File: file}
}

View File

@@ -1,15 +0,0 @@
//go:build !cgo
package dex
import (
sql "github.com/dexidp/dex/storage/sql"
)
// newSQLite3 for non-CGO builds. The dex SQLite3 stub has no fields and its
// Open() returns an error documenting the missing CGO support — correct
// behaviour for cross-compiled artefacts that never actually run the
// embedded IdP. The `file` argument is ignored.
func newSQLite3(_ string) *sql.SQLite3 {
return &sql.SQLite3{}
}

View File

@@ -351,6 +351,11 @@ initialize_default_values() {
NETBIRD_STUN_PORT=3478
# Docker images
# Record whether the operator explicitly pinned the server/proxy images via
# env vars, so the agent-network preset can pick its own defaults without
# clobbering an explicit override.
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
# Combined server replaces separate signal, relay, and management containers
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
@@ -398,7 +403,53 @@ configure_domain() {
return 0
}
apply_agent_network_preset() {
# Agent-network turnkey install: built-in Traefik + NetBird Proxy with
# NB_PROXY_PRIVATE=true, dashboard locked to agent-network-only mode.
# Bypasses every reverse-proxy / proxy / CrowdSec prompt. The only
# inputs we still need from the operator are the domain (handled by
# configure_domain via NETBIRD_DOMAIN env var or interactive prompt)
# and the ACME email — both honor env vars first and fall back to a
# prompt only when unset. CrowdSec is intentionally off.
REVERSE_PROXY_TYPE="0"
ENABLE_PROXY="true"
ENABLE_CROWDSEC="false"
# Agent-network ships dedicated server/proxy images. Honor an explicit
# env override; otherwise pin the agent-network builds.
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
fi
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
fi
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
else
TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email)
fi
echo "" > /dev/stderr
echo "Agent-network preset enabled (NETBIRD_AGENT_NETWORK=true):" > /dev/stderr
echo " - reverse proxy: built-in Traefik" > /dev/stderr
echo " - NetBird Proxy: enabled with NB_PROXY_PRIVATE=true" > /dev/stderr
echo " - server image: ${NETBIRD_SERVER_IMAGE}" > /dev/stderr
echo " - proxy image: ${NETBIRD_PROXY_IMAGE}" > /dev/stderr
echo " - dashboard: NETBIRD_AGENT_NETWORK_ONLY=true" > /dev/stderr
echo " - CrowdSec: disabled" > /dev/stderr
echo " - Let's Encrypt email: ${TRAEFIK_ACME_EMAIL}" > /dev/stderr
echo "" > /dev/stderr
}
configure_reverse_proxy() {
# Short-circuit: agent-network preset locks every reverse-proxy /
# proxy / CrowdSec choice and bypasses the interactive prompts.
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
apply_agent_network_preset
return 0
fi
# Prompt for reverse proxy type
REVERSE_PROXY_TYPE=$(read_reverse_proxy_type)
@@ -910,6 +961,15 @@ NGINX_SSL_PORT=443
# Letsencrypt
LETSENCRYPT_DOMAIN=none
EOF
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
cat <<EOF
# Agent-network preset: dashboard hides the standard NetBird surfaces
# and exposes only the AI Observability + agent-network configuration
# pages. Paired with NB_PROXY_PRIVATE=true on the proxy side.
NETBIRD_AGENT_NETWORK_ONLY=true
EOF
fi
return 0
}
@@ -946,6 +1006,17 @@ NB_PROXY_PROXY_PROTOCOL=true
NB_PROXY_TRUSTED_PROXIES=$TRAEFIK_IP
EOF
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
cat <<EOF
# Agent-network preset: turn the proxy into the private reverse-proxy
# ingress for agent-network synth services. Disables the public-facing
# surface so the proxy serves only synth-generated routes (the
# llm_router-driven LLM endpoints) and the per-account inbound
# listeners on the embedded netstack.
NB_PROXY_PRIVATE=true
EOF
fi
if [[ "$ENABLE_CROWDSEC" == "true" && -n "$CROWDSEC_BOUNCER_KEY" ]]; then
cat <<EOF
NB_PROXY_CROWDSEC_API_URL=http://crowdsec:8080
@@ -1326,12 +1397,20 @@ print_builtin_traefik_instructions() {
echo " - 51820/udp (WIREGUARD - (optional) for P2P proxy connections)"
fi
echo ""
echo "This setup is ideal for homelabs and smaller organization deployments."
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license or scaling your open source deployment:"
echo ""
echo " Commercial license: https://netbird.io/pricing#on-prem"
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license:"
echo ""
echo " Commercial license: https://netbird.ai/pricing"
echo " Documentation: https://docs.netbird.io/agent-network"
else
echo "This setup is ideal for homelabs and smaller organization deployments."
echo "For enterprise environments requiring high availability and advanced integrations,"
echo "consider a commercial on-prem license or scaling your open source deployment:"
echo ""
echo " Commercial license: https://netbird.io/pricing#on-prem"
echo " Scaling guide: https://docs.netbird.io/scaling-your-self-hosted-deployment"
fi
echo ""
if [[ "$ENABLE_PROXY" == "true" ]]; then
echo "NetBird Proxy:"
@@ -1354,6 +1433,11 @@ print_builtin_traefik_instructions() {
echo ""
fi
fi
if [[ "${NETBIRD_AGENT_NETWORK}" == "true" ]]; then
echo "Note: The public domain is only for setting up secure connections."
echo "Your APIs and agent services remain private and are never exposed publicly."
echo ""
fi
return 0
}

View File

@@ -56,12 +56,6 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
// componentsDisabled, when true, forces the controller to emit legacy
// proto.NetworkMap to every peer regardless of capability. Set once at
// construction and never written after — readers race-free without a
// mutex.
componentsDisabled bool
}
type bufferUpdate struct {
@@ -95,27 +89,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
settingsManager: settingsManager,
dnsDomain: dnsDomain,
config: config,
componentsDisabled: parseBoolEnv("NB_NETWORK_MAP_COMPONENTS_DISABLE"),
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
}
}
// PeerNeedsComponents reports whether the gRPC layer should emit the
// component-based wire format for this peer.
func (c *Controller) PeerNeedsComponents(p *nbpeer.Peer) bool {
return p != nil && p.SupportsComponentNetworkMap() && !c.componentsDisabled
}
// parseBoolEnv reads an env var via strconv.ParseBool so callers accept the
// usual "1/t/T/TRUE/true/True" set instead of being strict about a single
// literal.
func parseBoolEnv(key string) bool {
v, _ := strconv.ParseBool(os.Getenv(key))
return v
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
@@ -225,26 +204,18 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
result := account.GetPeerNetworkMapResult(ctx, p.ID, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap := proxyNetworkMaps[p.ID]
if result.NetworkMap != nil && proxyNetworkMap != nil {
result.NetworkMap.Merge(proxyNetworkMap)
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
var update *proto.SyncResponse
if result.IsComponents() {
// proxyNetworkMap rides the envelope as a ProxyPatch sidecar;
// the client merges it into Calculate()'s output the same
// way the legacy server did via NetworkMap.Merge.
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
}
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
@@ -454,11 +425,11 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
result := account.GetPeerNetworkMapResult(ctx, peerId, c.componentsDisabled, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap := proxyNetworkMaps[peer.ID]
if result.NetworkMap != nil && proxyNetworkMap != nil {
result.NetworkMap.Merge(proxyNetworkMap)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
extraSettings, err := c.settingsManager.GetExtraSettings(ctx, peer.AccountID)
@@ -469,12 +440,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
peerGroups := account.GetPeerGroups(peerId)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
var update *proto.SyncResponse
if result.IsComponents() {
update = grpc.ToComponentSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.Components, proxyNetworkMap, dnsDomain, postureChecks, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
} else {
update = grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, result.NetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
}
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
@@ -521,66 +487,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// GetValidatedPeerWithComponents is the components-format counterpart of
// GetValidatedPeerWithMap. It returns raw NetworkMapComponents for capable
// peers along with the proxy NetworkMap fragment (BYOP / port-forwarding
// data the legacy server folds in via NetworkMap.Merge). The gRPC layer
// encodes both into the wire envelope. Callers must gate on capability
// themselves before dispatching here — this method does NOT branch on it.
func (c *Controller) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
return peer, &types.NetworkMapComponents{Network: network.Copy()}, nil, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, nil, 0, err
}
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, nil, 0, err
}
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
return nil, nil, nil, nil, 0, err
}
// Fetch the proxy network map fragment for this peer alongside the
// components — same single-account-load path the streaming controller
// uses, so initial-sync delivers BYOP/forwarding patches synchronously
// instead of waiting for the next streaming push.
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, groupIDToUserIDs)
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, components, proxyNetworkMaps[peer.ID], postureChecks, dnsFwdPort, nil
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
if len(peerIDs) == 0 {

View File

@@ -24,10 +24,6 @@ type Controller interface {
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error)
// PeerNeedsComponents combines the peer's advertised capability with the
// kill-switch flag — the only public predicate gRPC layers should ask.
PeerNeedsComponents(p *nbpeer.Peer) bool
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)

View File

@@ -143,39 +143,6 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
}
// GetValidatedPeerWithComponents mocks base method.
func (m *MockController) GetValidatedPeerWithComponents(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMapComponents, *types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithComponents", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMapComponents)
ret2, _ := ret[2].(*types.NetworkMap)
ret3, _ := ret[3].([]*posture.Checks)
ret4, _ := ret[4].(int64)
ret5, _ := ret[5].(error)
return ret0, ret1, ret2, ret3, ret4, ret5
}
// GetValidatedPeerWithComponents indicates an expected call of GetValidatedPeerWithComponents.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithComponents(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithComponents", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithComponents), ctx, isRequiresApproval, accountID, p)
}
// PeerNeedsComponents mocks base method.
func (m *MockController) PeerNeedsComponents(p *peer.Peer) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PeerNeedsComponents", p)
ret0, _ := ret[0].(bool)
return ret0
}
// PeerNeedsComponents indicates an expected call of PeerNeedsComponents.
func (mr *MockControllerMockRecorder) PeerNeedsComponents(p any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PeerNeedsComponents", reflect.TypeOf((*MockController)(nil).PeerNeedsComponents), p)
}
// OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()

View File

@@ -1,813 +0,0 @@
package grpc
import (
"encoding/base64"
"strconv"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// wgKeyRawLen is the raw byte length of a WireGuard public key.
const wgKeyRawLen = 32
// ComponentsEnvelopeInput bundles the data the component-format encoder needs.
// The envelope is fully self-contained — every field needed by the client's
// local Calculate() comes from the components struct itself. The only
// externally-supplied data is the receiving peer's PeerConfig (which is
// computed alongside the components in the network_map controller and reused
// from the legacy proto path) and the dns_domain string.
type ComponentsEnvelopeInput struct {
Components *types.NetworkMapComponents
PeerConfig *proto.PeerConfig
DNSDomain string
DNSForwarderPort int64
// UserIDClaim is the OIDC claim name the client should embed in
// SshAuth.UserIDClaim when reconstructing the NetworkMap. Empty value
// is OK — client treats empty as "no SshAuth to build".
UserIDClaim string
// ProxyPatch carries pre-expanded NetworkMap fragments injected by
// external controllers (BYOP/port-forwarding). Nil when no proxy data
// is present; encoder skips the field in that case.
ProxyPatch *proto.ProxyPatch
}
// EncodeNetworkMapEnvelope converts NetworkMapComponents into the component
// wire envelope. The encoder is intentionally non-deterministic: it iterates
// Go maps in their native (random) order. Indexes inside the envelope
// (peer_indexes, source_group_ids, agent_version_idx, router_peer_indexes)
// are self-consistent within a single encode, so the decoder reconstructs
// the same typed objects regardless of emit order. Tests that need to
// compare envelopes do so semantically via proto round-trip + canonicalize,
// not byte-equal.
//
// Callers must NOT concatenate or merge envelopes from different encodes —
// index spaces are local to a single envelope.
func EncodeNetworkMapEnvelope(in ComponentsEnvelopeInput) *proto.NetworkMapEnvelope {
c := in.Components
// Graceful degrade when components is nil — matches the legacy path's
// behaviour for missing/unvalidated peers (return a NetworkMap with only
// Network populated). The receiver gets an envelope it can decode
// without crashing; AccountSettings stays non-nil so client-side
// dereferences are safe.
if c == nil {
// Match legacy missing-peer minimum: a NetworkMap with only Network
// populated. The receiver gets enough to bootstrap (Network
// identifier, dns_domain, account_settings) and nothing else.
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{
Full: &proto.NetworkMapComponentsFull{
PeerConfig: in.PeerConfig,
DnsDomain: in.DNSDomain,
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
AccountSettings: &proto.AccountSettingsCompact{},
ProxyPatch: in.ProxyPatch,
},
},
}
}
// Phase 1: build dedup tables. Every routing peer (in c.RouterPeers) and
// every regular peer (in c.Peers) must be indexed before any encoder
// looks up indexes via e.peerOrder — otherwise routes / routers_map for
// peers that exist only in c.RouterPeers would silently lose their
// peer_index reference.
enc := newComponentEncoder(c)
enc.indexAllPeers()
routerIdxs := enc.indexRouterPeers(c.RouterPeers)
// Phase 2: gather every policy that any consumer references (peer-pair
// policies + resource-only policies) so encodeResourcePoliciesMap can
// translate every *Policy pointer to a wire index.
allPolicies := unionPolicies(c.Policies, c.ResourcePoliciesMap)
policies, policyToIdxs := enc.encodePolicies(allPolicies)
// Phase 3: emit. Order of struct field expressions no longer matters:
// every encoder either reads from the dedup tables or works on
// independent input.
full := &proto.NetworkMapComponentsFull{
Serial: networkSerial(c.Network),
PeerConfig: in.PeerConfig,
Network: toAccountNetwork(c.Network),
AccountSettings: toAccountSettingsCompact(c.AccountSettings),
DnsForwarderPort: in.DNSForwarderPort,
UserIdClaim: in.UserIDClaim,
ProxyPatch: in.ProxyPatch,
DnsSettings: enc.encodeDNSSettings(c.DNSSettings),
DnsDomain: in.DNSDomain,
CustomZoneDomain: c.CustomZoneDomain,
AgentVersions: enc.agentVersions,
Peers: enc.peers,
RouterPeerIndexes: routerIdxs,
Policies: policies,
Groups: enc.encodeGroups(),
Routes: enc.encodeRoutes(c.Routes),
NameserverGroups: enc.encodeNameServerGroups(c.NameServerGroups),
AllDnsRecords: encodeSimpleRecords(c.AllDNSRecords),
AccountZones: encodeCustomZones(c.AccountZones),
NetworkResources: enc.encodeNetworkResources(c.NetworkResources),
RoutersMap: enc.encodeRoutersMap(c.RoutersMap),
ResourcePoliciesMap: enc.encodeResourcePoliciesMap(c.ResourcePoliciesMap, policyToIdxs),
GroupIdToUserIds: enc.encodeGroupIDToUserIDs(c.GroupIDToUserIDs),
AllowedUserIds: stringSetToSlice(c.AllowedUserIDs),
PostureFailedPeers: enc.encodePostureFailedPeers(c.PostureFailedPeers),
}
return &proto.NetworkMapEnvelope{
Payload: &proto.NetworkMapEnvelope_Full{Full: full},
}
}
// networkSerial returns c.Network.CurrentSerial() with a nil guard. The
// production path always populates c.Network, but the encoder is exported
// and a hand-built components struct may omit it.
func networkSerial(n *types.Network) uint64 {
if n == nil {
return 0
}
return n.CurrentSerial()
}
type componentEncoder struct {
components *types.NetworkMapComponents
peerOrder map[string]uint32
peers []*proto.PeerCompact
agentVersionOrder map[string]uint32
agentVersions []string
}
func newComponentEncoder(c *types.NetworkMapComponents) *componentEncoder {
return &componentEncoder{
components: c,
peerOrder: make(map[string]uint32, len(c.Peers)),
peers: make([]*proto.PeerCompact, 0, len(c.Peers)),
agentVersionOrder: make(map[string]uint32),
}
}
func (e *componentEncoder) indexAllPeers() {
for _, p := range e.components.Peers {
if p == nil {
continue
}
e.appendPeer(p)
}
}
func (e *componentEncoder) appendPeer(p *nbpeer.Peer) uint32 {
if idx, ok := e.peerOrder[p.ID]; ok {
return idx
}
idx := uint32(len(e.peers))
e.peerOrder[p.ID] = idx
e.peers = append(e.peers, toPeerCompact(p, e.agentVersionIndex(p.Meta.WtVersion)))
return idx
}
func (e *componentEncoder) agentVersionIndex(v string) uint32 {
if idx, ok := e.agentVersionOrder[v]; ok {
return idx
}
// Lazy-initialise the table with "" at index 0 so the empty string
// stays interchangeable with proto3's default uint32=0 — peers without
// a WtVersion don't force the table to materialise.
if v == "" {
idx := uint32(len(e.agentVersions))
if idx == 0 {
e.agentVersions = append(e.agentVersions, "")
}
e.agentVersionOrder[""] = idx
return idx
}
if len(e.agentVersions) == 0 {
e.agentVersions = append(e.agentVersions, "")
e.agentVersionOrder[""] = 0
}
idx := uint32(len(e.agentVersions))
e.agentVersionOrder[v] = idx
e.agentVersions = append(e.agentVersions, v)
return idx
}
// indexRouterPeers ensures every router peer is in the peer dedup table
// (c.RouterPeers may contain peers not in c.Peers when validation rules drop
// them) and returns their wire indexes for the RouterPeerIndexes field. Must
// run before any encoder that resolves peer ids via e.peerOrder.
func (e *componentEncoder) indexRouterPeers(routers map[string]*nbpeer.Peer) []uint32 {
if len(routers) == 0 {
return nil
}
out := make([]uint32, 0, len(routers))
for _, p := range routers {
if p == nil {
continue
}
out = append(out, e.appendPeer(p))
}
return out
}
func (e *componentEncoder) encodeGroups() []*proto.GroupCompact {
if len(e.components.Groups) == 0 {
return nil
}
out := make([]*proto.GroupCompact, 0, len(e.components.Groups))
for _, g := range e.components.Groups {
if !g.HasSeqID() {
continue
}
peerIdxs := make([]uint32, 0, len(g.Peers))
for _, peerID := range g.Peers {
if idx, ok := e.peerOrder[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
out = append(out, &proto.GroupCompact{
Id: g.AccountSeqID,
Name: g.Name,
PeerIndexes: peerIdxs,
})
}
return out
}
// encodePolicies flattens Policy{Rules} → []PolicyCompact. Returns the wire
// list and a map from policy pointer to the indexes of its emitted rules in
// that list — used by encodeResourcePoliciesMap to translate
// ResourcePoliciesMap[resourceID][]*Policy into wire-side indexes.
func (e *componentEncoder) encodePolicies(policies []*types.Policy) ([]*proto.PolicyCompact, map[*types.Policy][]uint32) {
if len(policies) == 0 {
return nil, nil
}
out := make([]*proto.PolicyCompact, 0, len(policies))
idxByPolicy := make(map[*types.Policy][]uint32, len(policies))
for _, pol := range policies {
if !pol.HasSeqID() || !pol.Enabled {
continue
}
for _, r := range pol.Rules {
if r == nil || !r.Enabled {
continue
}
idxByPolicy[pol] = append(idxByPolicy[pol], uint32(len(out)))
out = append(out, e.encodePolicyRule(pol, r))
}
}
return out, idxByPolicy
}
// encodePolicyRule maps a single PolicyRule under pol to a PolicyCompact entry.
func (e *componentEncoder) encodePolicyRule(pol *types.Policy, r *types.PolicyRule) *proto.PolicyCompact {
return &proto.PolicyCompact{
Id: pol.AccountSeqID,
Action: networkmap.GetProtoAction(string(r.Action)),
Protocol: networkmap.GetProtoProtocol(string(r.Protocol)),
Bidirectional: r.Bidirectional,
Ports: portsToUint32(r.Ports),
PortRanges: portRangesToProto(r.PortRanges),
SourceGroupIds: e.groupSeqIDs(r.Sources),
DestinationGroupIds: e.groupSeqIDs(r.Destinations),
AuthorizedUser: r.AuthorizedUser,
AuthorizedGroups: e.encodeAuthorizedGroups(r.AuthorizedGroups),
SourceResource: e.resourceToProto(r.SourceResource),
DestinationResource: e.resourceToProto(r.DestinationResource),
SourcePostureCheckSeqIds: e.postureCheckSeqs(pol.SourcePostureChecks),
}
}
// groupSeqIDs maps the xid group IDs in src to their per-account seq ids,
// dropping any group that has no seq id assigned.
func (e *componentEncoder) groupSeqIDs(src []string) []uint32 {
if len(src) == 0 {
return nil
}
out := make([]uint32, 0, len(src))
for _, gid := range src {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
// unionPolicies merges c.Policies with every policy referenced by
// c.ResourcePoliciesMap, deduplicating by pointer identity. Resource-only
// policies (relevant to a NetworkResource but not to peer-pair traffic)
// only live in ResourcePoliciesMap; without this union step they'd be lost
// from the wire and the client's resource-policy lookup would come back
// empty.
func unionPolicies(policies []*types.Policy, resourcePolicies map[string][]*types.Policy) []*types.Policy {
// Fast path: non-router peers have no resource-only policies, so the
// "union" is identical to `policies`. Skip the dedup map allocation.
if len(resourcePolicies) == 0 {
return policies
}
seen := make(map[*types.Policy]struct{}, len(policies))
out := make([]*types.Policy, 0, len(policies))
for _, p := range policies {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
for _, list := range resourcePolicies {
for _, p := range list {
if p == nil {
continue
}
if _, ok := seen[p]; ok {
continue
}
seen[p] = struct{}{}
out = append(out, p)
}
}
return out
}
// encodeAuthorizedGroups translates rule.AuthorizedGroups (map keyed by
// group xid → local-user names) to the wire form (map keyed by group
// account_seq_id → UserNameList). Groups without a seq id are dropped —
// matches how source/destination group references handle the same case.
func (e *componentEncoder) encodeAuthorizedGroups(m map[string][]string) map[uint32]*proto.UserNameList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserNameList, len(m))
for groupID, names := range m {
seq, ok := e.groupSeq(groupID)
if !ok {
continue
}
out[seq] = &proto.UserNameList{Names: append([]string(nil), names...)}
}
return out
}
func (e *componentEncoder) groupSeq(groupID string) (uint32, bool) {
g, ok := e.components.Groups[groupID]
if !ok || !g.HasSeqID() {
return 0, false
}
return g.AccountSeqID, true
}
// resourceToProto translates types.Resource for the wire. For peer-typed
// resources the peer id is converted to a peer index into the envelope's
// peers array. For other resource types only the type string is shipped
// today (Calculate's resource-typed rule path consults SourceResource only
// for "peer" — other types fall through to group-based lookup).
func (e *componentEncoder) resourceToProto(r types.Resource) *proto.ResourceCompact {
if r.ID == "" && r.Type == "" {
return nil
}
out := &proto.ResourceCompact{Type: string(r.Type)}
if r.Type == types.ResourceTypePeer && r.ID != "" {
if idx, ok := e.peerOrder[r.ID]; ok {
out.PeerIndexSet = true
out.PeerIndex = idx
}
}
return out
}
// postureCheckSeqs translates a slice of posture-check xids to their
// per-account integer ids using the NetworkMapComponents.PostureCheckXIDToSeq
// lookup. Unresolvable xids are silently dropped — matches how group/peer
// references handle the same case.
func (e *componentEncoder) postureCheckSeqs(xids []string) []uint32 {
if len(xids) == 0 || len(e.components.PostureCheckXIDToSeq) == 0 {
return nil
}
out := make([]uint32, 0, len(xids))
for _, xid := range xids {
if seq, ok := e.components.PostureCheckXIDToSeq[xid]; ok {
out = append(out, seq)
}
}
return out
}
// networkSeq translates a Network xid to its per-account integer id using
// the NetworkMapComponents.NetworkXIDToSeq lookup. Returns (0,false) when
// the xid isn't known — callers decide whether to skip the parent record.
func (e *componentEncoder) networkSeq(xid string) (uint32, bool) {
if xid == "" {
return 0, false
}
seq, ok := e.components.NetworkXIDToSeq[xid]
if !ok || seq == 0 {
return 0, false
}
return seq, true
}
func (e *componentEncoder) encodeDNSSettings(s *types.DNSSettings) *proto.DNSSettingsCompact {
if s == nil || len(s.DisabledManagementGroups) == 0 {
return nil
}
out := &proto.DNSSettingsCompact{
DisabledManagementGroupIds: make([]uint32, 0, len(s.DisabledManagementGroups)),
}
for _, gid := range s.DisabledManagementGroups {
if seq, ok := e.groupSeq(gid); ok {
out.DisabledManagementGroupIds = append(out.DisabledManagementGroupIds, seq)
}
}
return out
}
func (e *componentEncoder) encodeRoutes(routes []*nbroute.Route) []*proto.RouteRaw {
if len(routes) == 0 {
return nil
}
out := make([]*proto.RouteRaw, 0, len(routes))
for _, r := range routes {
if r == nil {
continue
}
rr := &proto.RouteRaw{
Id: r.AccountSeqID,
NetId: string(r.NetID),
Description: r.Description,
KeepRoute: r.KeepRoute,
NetworkType: int32(r.NetworkType),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
SkipAutoApply: r.SkipAutoApply,
Domains: r.Domains.ToPunycodeList(),
GroupIds: e.groupIDsToSeq(r.Groups),
AccessControlGroupIds: e.groupIDsToSeq(r.AccessControlGroups),
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
}
if r.Network.IsValid() {
rr.NetworkCidr = r.Network.String()
}
if r.Peer != "" {
if idx, ok := e.peerOrder[r.Peer]; ok {
rr.PeerIndexSet = true
rr.PeerIndex = idx
}
}
out = append(out, rr)
}
return out
}
func (e *componentEncoder) groupIDsToSeq(groupIDs []string) []uint32 {
if len(groupIDs) == 0 {
return nil
}
out := make([]uint32, 0, len(groupIDs))
for _, gid := range groupIDs {
if seq, ok := e.groupSeq(gid); ok {
out = append(out, seq)
}
}
return out
}
func (e *componentEncoder) encodeNameServerGroups(nsgs []*nbdns.NameServerGroup) []*proto.NameServerGroupRaw {
if len(nsgs) == 0 {
return nil
}
out := make([]*proto.NameServerGroupRaw, 0, len(nsgs))
for _, nsg := range nsgs {
if nsg == nil {
continue
}
entry := &proto.NameServerGroupRaw{
Id: nsg.AccountSeqID,
Name: nsg.Name,
Description: nsg.Description,
Nameservers: encodeNameServers(nsg.NameServers),
GroupIds: e.groupIDsToSeq(nsg.Groups),
Primary: nsg.Primary,
Domains: nsg.Domains,
Enabled: nsg.Enabled,
SearchDomainsEnabled: nsg.SearchDomainsEnabled,
}
out = append(out, entry)
}
return out
}
func encodeNameServers(servers []nbdns.NameServer) []*proto.NameServer {
if len(servers) == 0 {
return nil
}
out := make([]*proto.NameServer, 0, len(servers))
for _, s := range servers {
out = append(out, &proto.NameServer{
IP: s.IP.String(),
NSType: int64(s.NSType),
Port: int64(s.Port),
})
}
return out
}
func encodeSimpleRecords(records []nbdns.SimpleRecord) []*proto.SimpleRecord {
if len(records) == 0 {
return nil
}
out := make([]*proto.SimpleRecord, 0, len(records))
for _, r := range records {
out = append(out, &proto.SimpleRecord{
Name: r.Name,
Type: int64(r.Type),
Class: r.Class,
TTL: int64(r.TTL),
RData: r.RData,
})
}
return out
}
func encodeCustomZones(zones []nbdns.CustomZone) []*proto.CustomZone {
if len(zones) == 0 {
return nil
}
out := make([]*proto.CustomZone, 0, len(zones))
for _, z := range zones {
out = append(out, &proto.CustomZone{
Domain: z.Domain,
Records: encodeSimpleRecords(z.Records),
SearchDomainDisabled: z.SearchDomainDisabled,
NonAuthoritative: z.NonAuthoritative,
})
}
return out
}
func (e *componentEncoder) encodeNetworkResources(resources []*resourceTypes.NetworkResource) []*proto.NetworkResourceRaw {
if len(resources) == 0 {
return nil
}
out := make([]*proto.NetworkResourceRaw, 0, len(resources))
for _, r := range resources {
if r == nil {
continue
}
entry := &proto.NetworkResourceRaw{
Id: r.AccountSeqID,
Name: r.Name,
Description: r.Description,
Type: string(r.Type),
Address: r.Address,
DomainValue: r.Domain,
Enabled: r.Enabled,
}
if seq, ok := e.networkSeq(r.NetworkID); ok {
entry.NetworkSeq = seq
}
if r.Prefix.IsValid() {
entry.PrefixCidr = r.Prefix.String()
}
out = append(out, entry)
}
return out
}
func (e *componentEncoder) encodeRoutersMap(routersMap map[string]map[string]*routerTypes.NetworkRouter) map[uint32]*proto.NetworkRouterList {
if len(routersMap) == 0 {
return nil
}
out := make(map[uint32]*proto.NetworkRouterList, len(routersMap))
for networkXID, routers := range routersMap {
if len(routers) == 0 {
continue
}
netSeq, ok := e.networkSeq(networkXID)
if !ok {
continue
}
entries := make([]*proto.NetworkRouterEntry, 0, len(routers))
for peerID, r := range routers {
if r == nil {
continue
}
entry := &proto.NetworkRouterEntry{
Id: r.AccountSeqID,
PeerGroupIds: e.groupIDsToSeq(r.PeerGroups),
Masquerade: r.Masquerade,
Metric: int32(r.Metric),
Enabled: r.Enabled,
}
if idx, ok := e.peerOrder[peerID]; ok {
entry.PeerIndexSet = true
entry.PeerIndex = idx
}
entries = append(entries, entry)
}
out[netSeq] = &proto.NetworkRouterList{Entries: entries}
}
return out
}
func (e *componentEncoder) encodeResourcePoliciesMap(rpm map[string][]*types.Policy, policyToIdxs map[*types.Policy][]uint32) map[uint32]*proto.PolicyIndexes {
if len(rpm) == 0 {
return nil
}
// resourceXIDToSeq is local to one encode — built from components.NetworkResources
// (small slice). Network resources without seq id are dropped, matching how
// other components-without-seq are silently filtered.
resourceXIDToSeq := make(map[string]uint32, len(e.components.NetworkResources))
for _, r := range e.components.NetworkResources {
if r != nil && r.AccountSeqID != 0 {
resourceXIDToSeq[r.ID] = r.AccountSeqID
}
}
out := make(map[uint32]*proto.PolicyIndexes, len(rpm))
for resourceXID, policies := range rpm {
seq, ok := resourceXIDToSeq[resourceXID]
if !ok {
continue
}
idxs := make([]uint32, 0, len(policies)*2)
for _, pol := range policies {
idxs = append(idxs, policyToIdxs[pol]...)
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PolicyIndexes{Indexes: idxs}
}
return out
}
func (e *componentEncoder) encodeGroupIDToUserIDs(m map[string][]string) map[uint32]*proto.UserIDList {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.UserIDList, len(m))
for groupID, userIDs := range m {
seq, ok := e.groupSeq(groupID)
if !ok || len(userIDs) == 0 {
continue
}
out[seq] = &proto.UserIDList{UserIds: userIDs}
}
return out
}
func stringSetToSlice(s map[string]struct{}) []string {
if len(s) == 0 {
return nil
}
out := make([]string, 0, len(s))
for k := range s {
out = append(out, k)
}
return out
}
func (e *componentEncoder) encodePostureFailedPeers(m map[string]map[string]struct{}) map[uint32]*proto.PeerIndexSet {
if len(m) == 0 {
return nil
}
out := make(map[uint32]*proto.PeerIndexSet, len(m))
for checkXID, failedPeerIDs := range m {
seq, ok := e.components.PostureCheckXIDToSeq[checkXID]
if !ok || seq == 0 {
continue
}
idxs := make([]uint32, 0, len(failedPeerIDs))
for peerID := range failedPeerIDs {
if idx, ok := e.peerOrder[peerID]; ok {
idxs = append(idxs, idx)
}
}
if len(idxs) == 0 {
continue
}
out[seq] = &proto.PeerIndexSet{PeerIndexes: idxs}
}
return out
}
// toAccountSettingsCompact always returns a non-nil message — the client
// dereferences it unconditionally during Calculate(), so a nil here would
// crash the receiver. A missing types.AccountSettingsInfo on the server
// (which shouldn't happen in production but the encoder is exported)
// degrades to login_expiration_enabled = false, which makes
// LoginExpired() return false for every peer.
func toAccountSettingsCompact(s *types.AccountSettingsInfo) *proto.AccountSettingsCompact {
if s == nil {
return &proto.AccountSettingsCompact{}
}
return &proto.AccountSettingsCompact{
PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled,
PeerLoginExpirationNs: int64(s.PeerLoginExpiration),
}
}
func toAccountNetwork(n *types.Network) *proto.AccountNetwork {
if n == nil {
return nil
}
out := &proto.AccountNetwork{
Identifier: n.Identifier,
NetCidr: n.Net.String(),
Dns: n.Dns,
Serial: n.CurrentSerial(),
}
if len(n.NetV6.IP) > 0 {
out.NetV6Cidr = n.NetV6.String()
}
return out
}
func toPeerCompact(p *nbpeer.Peer, agentVersionIdx uint32) *proto.PeerCompact {
pc := &proto.PeerCompact{
WgPubKey: decodeWgKey(p.Key),
SshPubKey: []byte(p.SSHKey),
DnsLabel: p.DNSLabel,
AgentVersionIdx: agentVersionIdx,
AddedWithSsoLogin: p.UserID != "",
LoginExpirationEnabled: p.LoginExpirationEnabled,
SshEnabled: p.SSHEnabled,
SupportsIpv6: p.SupportsIPv6(),
SupportsSourcePrefixes: p.SupportsSourcePrefixes(),
ServerSshAllowed: p.Meta.Flags.ServerSSHAllowed,
}
if p.LastLogin != nil {
pc.LastLoginUnixNano = p.LastLogin.UnixNano()
}
switch {
case !p.IP.IsValid():
// leave Ip nil
case p.IP.Is4() || p.IP.Is4In6():
ip := p.IP.Unmap().As4()
pc.Ip = ip[:]
default:
ip := p.IP.As16()
pc.Ip = ip[:]
}
if p.IPv6.IsValid() {
ip := p.IPv6.As16()
pc.Ipv6 = ip[:]
}
return pc
}
// decodeWgKey returns the raw 32 bytes of a base64-encoded WireGuard public
// key, or nil for an empty / malformed key.
func decodeWgKey(s string) []byte {
if s == "" {
return nil
}
out := make([]byte, wgKeyRawLen)
n, err := base64.StdEncoding.Decode(out, []byte(s))
if err != nil || n != wgKeyRawLen {
return nil
}
return out
}
func portsToUint32(ports []string) []uint32 {
if len(ports) == 0 {
return nil
}
out := make([]uint32, 0, len(ports))
for _, p := range ports {
v, err := strconv.ParseUint(p, 10, 16)
if err != nil {
continue
}
out = append(out, uint32(v))
}
return out
}
func portRangesToProto(ranges []types.RulePortRange) []*proto.PortInfo_Range {
if len(ranges) == 0 {
return nil
}
out := make([]*proto.PortInfo_Range, 0, len(ranges))
for _, r := range ranges {
out = append(out, &proto.PortInfo_Range{
Start: uint32(r.Start),
End: uint32(r.End),
})
}
return out
}

View File

@@ -1,879 +0,0 @@
package grpc
import (
"bytes"
"cmp"
"net"
"net/netip"
"slices"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
goproto "google.golang.org/protobuf/proto"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
)
const testWgKeyA = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyB = "BBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
const testWgKeyC = "CBCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopq="
// canonicalize rewrites a NetworkMapComponentsFull in place into a canonical
// form: peers reordered by wg_pub_key, with the rest of the message rewritten
// to reference the new peer indexes. Groups, policies, and router indexes are
// also sorted. After canonicalize, two envelopes built from the same logical
// input compare byte-equal via proto.Equal.
//
// This lives on the test side — the encoder itself emits in map-iteration
// order. Test-side normalization is the contract for "two encodes are
// equivalent".
func canonicalize(full *proto.NetworkMapComponentsFull) {
if full == nil {
return
}
// Canonicalize agent_versions first: sort the slice and rewrite each
// peer's AgentVersionIdx accordingly. The empty placeholder stays at
// index 0 by convention.
avRemap := make(map[uint32]uint32, len(full.AgentVersions))
if len(full.AgentVersions) > 0 {
// Pair version → original index, sort, rebuild.
type avEntry struct {
version string
oldIdx uint32
}
entries := make([]avEntry, len(full.AgentVersions))
for i, v := range full.AgentVersions {
entries[i] = avEntry{version: v, oldIdx: uint32(i)}
}
// Empty stays at 0; sort the rest by string. Tiebreaker on oldIdx
// keeps the canonicalize output stable when two entries compare
// equal (the encoder dedups, but defending against future inputs).
slices.SortFunc(entries, func(a, b avEntry) int {
if a.version == "" && b.version != "" {
return -1
}
if b.version == "" && a.version != "" {
return 1
}
if c := cmp.Compare(a.version, b.version); c != 0 {
return c
}
return cmp.Compare(a.oldIdx, b.oldIdx)
})
newVersions := make([]string, len(entries))
for newIdx, e := range entries {
avRemap[e.oldIdx] = uint32(newIdx)
newVersions[newIdx] = e.version
}
full.AgentVersions = newVersions
}
for _, p := range full.Peers {
if newIdx, ok := avRemap[p.AgentVersionIdx]; ok {
p.AgentVersionIdx = newIdx
}
}
type peerEntry struct {
peer *proto.PeerCompact
oldIdx uint32
}
entries := make([]peerEntry, len(full.Peers))
for i, p := range full.Peers {
entries[i] = peerEntry{peer: p, oldIdx: uint32(i)}
}
// DnsLabel is unique per peer; it tiebreaks on equal WgPubKey (e.g. both
// nil from malformed keys, or both empty for placeholders).
slices.SortFunc(entries, func(a, b peerEntry) int {
if c := bytes.Compare(a.peer.WgPubKey, b.peer.WgPubKey); c != 0 {
return c
}
return cmp.Compare(a.peer.DnsLabel, b.peer.DnsLabel)
})
remap := make(map[uint32]uint32, len(entries))
newPeers := make([]*proto.PeerCompact, len(entries))
for newIdx, e := range entries {
remap[e.oldIdx] = uint32(newIdx)
newPeers[newIdx] = e.peer
}
full.Peers = newPeers
full.RouterPeerIndexes = remapAndSort(full.RouterPeerIndexes, remap)
for _, g := range full.Groups {
g.PeerIndexes = remapAndSort(g.PeerIndexes, remap)
}
slices.SortFunc(full.Groups, func(a, b *proto.GroupCompact) int { return cmp.Compare(a.Id, b.Id) })
for _, r := range full.Routes {
if r.PeerIndexSet {
if newIdx, ok := remap[r.PeerIndex]; ok {
r.PeerIndex = newIdx
}
}
slices.Sort(r.GroupIds)
slices.Sort(r.AccessControlGroupIds)
slices.Sort(r.PeerGroupIds)
}
slices.SortFunc(full.Routes, func(a, b *proto.RouteRaw) int { return cmp.Compare(a.Id, b.Id) })
for _, list := range full.RoutersMap {
for _, entry := range list.Entries {
if entry.PeerIndexSet {
if newIdx, ok := remap[entry.PeerIndex]; ok {
entry.PeerIndex = newIdx
}
}
slices.Sort(entry.PeerGroupIds)
}
slices.SortFunc(list.Entries, func(a, b *proto.NetworkRouterEntry) int { return cmp.Compare(a.Id, b.Id) })
}
for _, set := range full.PostureFailedPeers {
set.PeerIndexes = remapAndSort(set.PeerIndexes, remap)
}
for _, p := range full.Policies {
slices.Sort(p.SourceGroupIds)
slices.Sort(p.DestinationGroupIds)
}
// Sort policies by (Id, source_group_ids, destination_group_ids) so that
// multiple PolicyCompact entries sharing the same Id (one per rule, when
// a Policy has multiple rules) still get a deterministic order. After
// sorting we remap indexes in ResourcePoliciesMap.
policyOldOrder := make(map[*proto.PolicyCompact]uint32, len(full.Policies))
for i, p := range full.Policies {
policyOldOrder[p] = uint32(i)
}
slices.SortFunc(full.Policies, func(a, b *proto.PolicyCompact) int {
if c := cmp.Compare(a.Id, b.Id); c != 0 {
return c
}
if c := slices.Compare(a.SourceGroupIds, b.SourceGroupIds); c != 0 {
return c
}
return slices.Compare(a.DestinationGroupIds, b.DestinationGroupIds)
})
policyRemap := make(map[uint32]uint32, len(full.Policies))
for newIdx, p := range full.Policies {
policyRemap[policyOldOrder[p]] = uint32(newIdx)
}
for _, idxs := range full.ResourcePoliciesMap {
idxs.Indexes = remapAndSort(idxs.Indexes, policyRemap)
}
for _, list := range full.GroupIdToUserIds {
slices.Sort(list.UserIds)
}
slices.Sort(full.AllowedUserIds)
}
func remapAndSort(idxs []uint32, remap map[uint32]uint32) []uint32 {
out := make([]uint32, 0, len(idxs))
for _, i := range idxs {
if newIdx, ok := remap[i]; ok {
out = append(out, newIdx)
}
}
slices.Sort(out)
return out
}
// envelopesEquivalent decodes both envelopes, canonicalizes them, and reports
// whether they're proto.Equal. Use instead of byte-comparing marshaled output:
// the encoder is intentionally non-deterministic.
func envelopesEquivalent(a, b *proto.NetworkMapEnvelope) bool {
canonicalize(a.GetFull())
canonicalize(b.GetFull())
return goproto.Equal(a, b)
}
func newTestComponents() *types.NetworkMapComponents {
peerA := &nbpeer.Peer{
ID: "peer-a",
Key: testWgKeyA,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 1}),
DNSLabel: "peera",
SSHKey: "ssh-a",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
peerB := &nbpeer.Peer{
ID: "peer-b",
Key: testWgKeyB,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 2}),
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}),
DNSLabel: "peerb",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.25.0"},
}
peerC := &nbpeer.Peer{
ID: "peer-c",
Key: testWgKeyC,
IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
return &types.NetworkMapComponents{
PeerID: "peer-a",
Network: &types.Network{
Identifier: "net-test",
Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)},
Serial: 7,
},
AccountSettings: &types.AccountSettingsInfo{
PeerLoginExpirationEnabled: true,
PeerLoginExpiration: 2 * time.Hour,
},
Peers: map[string]*nbpeer.Peer{
"peer-a": peerA,
"peer-b": peerB,
"peer-c": peerC,
},
Groups: map[string]*types.Group{
"group-src": {ID: "group-src", AccountSeqID: 1, Name: "Src", Peers: []string{"peer-a"}},
"group-dst": {ID: "group-dst", AccountSeqID: 2, Name: "Dst", Peers: []string{"peer-b", "peer-c"}},
},
Policies: []*types.Policy{
{
ID: "pol-1",
AccountSeqID: 10,
Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-1", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true,
Ports: []string{"22", "80"},
PortRanges: []types.RulePortRange{{Start: 8000, End: 8100}},
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
},
},
RouterPeers: map[string]*nbpeer.Peer{"peer-c": peerC},
}
}
func TestEncodeNetworkMapEnvelope_Basic(t *testing.T) {
c := newTestComponents()
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full, "envelope must contain Full payload")
assert.EqualValues(t, 7, full.Serial)
assert.Equal(t, "netbird.cloud", full.DnsDomain)
require.NotNil(t, full.Network)
assert.Equal(t, "net-test", full.Network.Identifier)
assert.Equal(t, "100.64.0.0/10", full.Network.NetCidr)
require.NotNil(t, full.AccountSettings)
assert.True(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.EqualValues(t, (2 * time.Hour).Nanoseconds(), full.AccountSettings.PeerLoginExpirationNs)
require.Len(t, full.Peers, 3)
byLabel := map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
assert.Len(t, p.WgPubKey, 32, "wg key must be raw 32 bytes")
assert.Len(t, p.Ip, 4, "ipv4 must be raw 4 bytes")
byLabel[p.DnsLabel] = p
}
assert.Len(t, byLabel["peerb"].Ipv6, 16, "peer-b has ipv6 → 16 bytes")
}
func TestEncodeNetworkMapEnvelope_RepeatEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
// Hammer it 100 times — Go map iteration is randomized per call, so each
// run produces different wire bytes, but the canonicalized form must
// match.
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d must be semantically equivalent to first encode", i)
}
}
func TestEncodeNetworkMapEnvelope_ConcurrentEncodesEquivalent(t *testing.T) {
c := newTestComponents()
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
const goroutines = 50
var wg sync.WaitGroup
wg.Add(goroutines)
results := make([]*proto.NetworkMapEnvelope, goroutines)
for i := 0; i < goroutines; i++ {
i := i
go func() {
defer wg.Done()
results[i] = EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
}()
}
wg.Wait()
for i, got := range results {
require.NotNil(t, got, "goroutine %d returned nil", i)
require.True(t, envelopesEquivalent(expected, got),
"goroutine %d produced inequivalent envelope", i)
}
}
func TestEncodeNetworkMapEnvelope_GroupsByAccountSeqID(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 2)
groupByID := map[uint32]*proto.GroupCompact{}
for _, g := range full.Groups {
groupByID[g.Id] = g
}
require.Contains(t, groupByID, uint32(1))
require.Contains(t, groupByID, uint32(2))
assert.Equal(t, "Src", groupByID[1].Name)
assert.Equal(t, "Dst", groupByID[2].Name)
assert.Len(t, groupByID[1].PeerIndexes, 1)
assert.Len(t, groupByID[2].PeerIndexes, 2)
}
func TestEncodeNetworkMapEnvelope_PolicyExpansion(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.EqualValues(t, 10, pc.Id)
assert.Equal(t, proto.RuleAction_ACCEPT, pc.Action)
assert.Equal(t, proto.RuleProtocol_TCP, pc.Protocol)
assert.True(t, pc.Bidirectional)
assert.Equal(t, []uint32{22, 80}, pc.Ports)
require.Len(t, pc.PortRanges, 1)
assert.EqualValues(t, 8000, pc.PortRanges[0].Start)
assert.EqualValues(t, 8100, pc.PortRanges[0].End)
assert.Equal(t, []uint32{1}, pc.SourceGroupIds)
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouterIndexes(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.RouterPeerIndexes, 1)
idx := full.RouterPeerIndexes[0]
require.Less(t, int(idx), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[idx].DnsLabel)
}
func TestEncodeNetworkMapEnvelope_AgentVersionDedup(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.AgentVersions, 3, "empty placeholder + 2 distinct versions")
assert.Equal(t, "", full.AgentVersions[0], "index 0 reserved for empty version")
assert.ElementsMatch(t, []string{"0.40.0", "0.25.0"}, full.AgentVersions[1:],
"two distinct versions, order depends on map iteration")
idxByLabel := map[string]uint32{}
for _, p := range full.Peers {
idxByLabel[p.DnsLabel] = p.AgentVersionIdx
}
assert.Equal(t, idxByLabel["peera"], idxByLabel["peerc"], "peers with the same agent version share an index")
assert.NotEqual(t, idxByLabel["peera"], idxByLabel["peerb"])
}
func TestEncodeNetworkMapEnvelope_DisabledPolicySkipped(t *testing.T) {
c := newTestComponents()
c.Policies[0].Enabled = false
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_GroupZeroSeqIDSkipped(t *testing.T) {
c := newTestComponents()
c.Groups["group-src"].AccountSeqID = 0
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Groups, 1, "groups with AccountSeqID=0 are not yet persisted and must be skipped")
assert.EqualValues(t, 2, full.Groups[0].Id)
require.Len(t, full.Policies, 1)
pc := full.Policies[0]
assert.Empty(t, pc.SourceGroupIds, "rule references a group that was filtered out → no group id on wire")
assert.Equal(t, []uint32{2}, pc.DestinationGroupIds)
}
func TestEncodeNetworkMapEnvelope_TwoPeersSameMalformedKey(t *testing.T) {
// Both peers have nil WgPubKey after decode; canonicalize must still
// produce a stable order using DnsLabel as a tiebreaker, so 100 encodes
// canonicalize identically.
c := newTestComponents()
c.Peers["peer-a"].Key = "garbage-a-!!!"
c.Peers["peer-b"].Key = "garbage-b-!!!"
expected := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
for i := 0; i < 100; i++ {
got := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
require.True(t, envelopesEquivalent(expected, got),
"encode #%d with two same-key peers must canonicalize equivalently", i)
}
}
func TestEncodeNetworkMapEnvelope_MalformedWgKey(t *testing.T) {
c := newTestComponents()
c.Peers["peer-a"].Key = "not-base64-!!!"
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Peers, 3)
var byLabel = map[string]*proto.PeerCompact{}
for _, p := range full.Peers {
byLabel[p.DnsLabel] = p
}
assert.Nil(t, byLabel["peera"].WgPubKey, "peer with malformed key encodes nil WgPubKey")
assert.Len(t, byLabel["peerb"].WgPubKey, 32, "other peers retain their key")
}
func TestEncodeNetworkMapEnvelope_IPv6OnlyPeer(t *testing.T) {
c := newTestComponents()
v6Only := &nbpeer.Peer{
ID: "peer-v6",
Key: testWgKeyA,
IPv6: netip.AddrFrom16([16]byte{0xfd, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9}),
DNSLabel: "peerv6",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.Peers["peer-v6"] = v6Only
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerv6" {
found = p
}
}
require.NotNil(t, found, "ipv6-only peer must be present")
assert.Empty(t, found.Ip, "no IPv4 address → empty Ip")
assert.Len(t, found.Ipv6, 16)
}
func TestEncodeNetworkMapEnvelope_PeerWithoutIP(t *testing.T) {
c := newTestComponents()
c.Peers["peer-noip"] = &nbpeer.Peer{
ID: "peer-noip",
Key: testWgKeyA,
DNSLabel: "peernoip",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var found *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peernoip" {
found = p
}
}
require.NotNil(t, found)
assert.Empty(t, found.Ip)
assert.Empty(t, found.Ipv6)
}
func TestEncodeNetworkMapEnvelope_EmptyInput(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
}
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c})
full := env.GetFull()
require.NotNil(t, full)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Groups)
assert.Empty(t, full.Policies)
assert.Empty(t, full.RouterPeerIndexes)
require.NotNil(t, full.AccountSettings, "AccountSettingsCompact must always be emitted (client dereferences it unconditionally)")
}
func TestEncodeNetworkMapEnvelope_PeerLoginExpirationFields(t *testing.T) {
c := newTestComponents()
now := time.Date(2024, 1, 2, 3, 4, 5, 0, time.UTC)
c.Peers["peer-a"].UserID = "user-1"
c.Peers["peer-a"].LoginExpirationEnabled = true
c.Peers["peer-a"].LastLogin = &now
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
var pa *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peera" {
pa = p
}
}
require.NotNil(t, pa)
assert.True(t, pa.AddedWithSsoLogin)
assert.True(t, pa.LoginExpirationEnabled)
assert.Equal(t, now.UnixNano(), pa.LastLoginUnixNano)
// peer-b has no UserID and no LastLogin → all fields zero-value.
var pb *proto.PeerCompact
for _, p := range full.Peers {
if p.DnsLabel == "peerb" {
pb = p
}
}
require.NotNil(t, pb)
assert.False(t, pb.AddedWithSsoLogin)
assert.False(t, pb.LoginExpirationEnabled)
assert.Zero(t, pb.LastLoginUnixNano)
}
func TestEncodeNetworkMapEnvelope_RoutesRoundTrip(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{
{
ID: "route-peer",
AccountSeqID: 100,
NetID: "net-A",
Description: "via peer-c",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Peer: "peer-c", // peer ID, not WG key
Groups: []string{"group-src"},
AccessControlGroups: []string{"group-dst"},
Enabled: true,
},
{
ID: "route-peergroup",
AccountSeqID: 101,
NetID: "net-B",
Network: netip.MustParsePrefix("10.1.0.0/16"),
PeerGroups: []string{"group-src", "group-dst"},
Enabled: true,
},
{
ID: "route-no-seq",
AccountSeqID: 0, // unset — should still ship (no group seq filter on routes)
Network: netip.MustParsePrefix("10.2.0.0/16"),
Enabled: true,
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 3)
byNetID := map[string]*proto.RouteRaw{}
for _, r := range full.Routes {
byNetID[r.NetId] = r
}
r1 := byNetID["net-A"]
require.NotNil(t, r1)
assert.True(t, r1.PeerIndexSet, "route with peer must set peer_index_set")
require.Less(t, int(r1.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[r1.PeerIndex].DnsLabel)
assert.Equal(t, []uint32{1}, r1.GroupIds, "group-src has AccountSeqID 1")
assert.Equal(t, []uint32{2}, r1.AccessControlGroupIds, "group-dst has AccountSeqID 2")
assert.Empty(t, r1.PeerGroupIds)
r2 := byNetID["net-B"]
require.NotNil(t, r2)
assert.False(t, r2.PeerIndexSet, "route with peer_groups must NOT set peer_index_set")
assert.ElementsMatch(t, []uint32{1, 2}, r2.PeerGroupIds)
}
func TestEncodeNetworkMapEnvelope_RouteWithMissingPeerLeavesIndexUnset(t *testing.T) {
c := newTestComponents()
c.Routes = []*nbroute.Route{{
ID: "route-x",
AccountSeqID: 100,
Peer: "peer-not-in-components",
Network: netip.MustParsePrefix("10.0.0.0/16"),
Enabled: true,
}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Routes, 1)
assert.False(t, full.Routes[0].PeerIndexSet,
"missing peer reference must not pretend to point at peer index 0")
}
func TestEncodeNetworkMapEnvelope_ResourceOnlyPolicyShippedAndIndexed(t *testing.T) {
c := newTestComponents()
// Policy that exists ONLY in ResourcePoliciesMap, not in c.Policies. This
// is the I1 case — without unionPolicies the encoder would silently
// drop it from the wire.
resourceOnlyPolicy := &types.Policy{
ID: "pol-resource", AccountSeqID: 99, Enabled: true,
Rules: []*types.PolicyRule{{
ID: "rule-r", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolTCP,
Sources: []string{"group-src"},
Destinations: []string{"group-dst"},
}},
}
c.ResourcePoliciesMap = map[string][]*types.Policy{
"resource-x": {c.Policies[0], resourceOnlyPolicy}, // shared + resource-only
}
// Resource must appear in components.NetworkResources with a seq id —
// encoder uses that to translate the xid map key to uint32.
c.NetworkResources = []*resourceTypes.NetworkResource{
{ID: "resource-x", AccountSeqID: 77, Name: "res-x", Enabled: true},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.Policies, 2, "encoded policies must include both peer-traffic and resource-only")
policyByID := map[uint32]*proto.PolicyCompact{}
policyIdxByID := map[uint32]uint32{}
for i, p := range full.Policies {
policyByID[p.Id] = p
policyIdxByID[p.Id] = uint32(i)
}
require.Contains(t, policyByID, uint32(10), "original peer-traffic policy id 10")
require.Contains(t, policyByID, uint32(99), "resource-only policy id 99")
require.Contains(t, full.ResourcePoliciesMap, uint32(77))
idxs := full.ResourcePoliciesMap[77].Indexes
require.Len(t, idxs, 2)
assert.ElementsMatch(t, []uint32{policyIdxByID[10], policyIdxByID[99]}, idxs,
"resource policies map must reference both wire policy indexes")
}
func TestEncodeNetworkMapEnvelope_NameServerGroups(t *testing.T) {
c := newTestComponents()
c.NameServerGroups = []*nbdns.NameServerGroup{{
ID: "nsg-1", AccountSeqID: 50, Name: "Main", Description: "primary",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53,
}},
Groups: []string{"group-src", "group-not-persisted"},
Primary: true, Enabled: true,
Domains: []string{"corp.example"},
}}
c.Groups["group-not-persisted"] = &types.Group{ID: "group-not-persisted", AccountSeqID: 0, Peers: []string{}}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.NameserverGroups, 1)
nsg := full.NameserverGroups[0]
assert.EqualValues(t, 50, nsg.Id)
assert.Equal(t, "Main", nsg.Name)
assert.True(t, nsg.Primary)
require.Len(t, nsg.Nameservers, 1)
assert.Equal(t, "8.8.8.8", nsg.Nameservers[0].IP)
assert.Equal(t, []uint32{1}, nsg.GroupIds, "group-not-persisted is filtered out (AccountSeqID=0)")
}
func TestEncodeNetworkMapEnvelope_PostureFailedPeers(t *testing.T) {
c := newTestComponents()
c.PostureCheckXIDToSeq = map[string]uint32{"check-1": 33}
c.PostureFailedPeers = map[string]map[string]struct{}{
"check-1": {
"peer-a": {},
"peer-b": {},
"peer-not-in-account": {},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.PostureFailedPeers, uint32(33))
idxs := full.PostureFailedPeers[33].PeerIndexes
assert.Len(t, idxs, 2, "missing peer is silently dropped (filterPostureFailedPeers guarantees presence in real data)")
}
func TestEncodeNetworkMapEnvelope_RoutersMap(t *testing.T) {
c := newTestComponents()
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {
"peer-c": {
ID: "router-1", AccountSeqID: 200,
Peer: "peer-c", Masquerade: true, Metric: 10, Enabled: true,
},
},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
entries := full.RoutersMap[5].Entries
require.Len(t, entries, 1)
e := entries[0]
assert.EqualValues(t, 200, e.Id)
assert.True(t, e.PeerIndexSet)
require.Less(t, int(e.PeerIndex), len(full.Peers))
assert.Equal(t, "peerc", full.Peers[e.PeerIndex].DnsLabel)
assert.True(t, e.Masquerade)
assert.EqualValues(t, 10, e.Metric)
assert.True(t, e.Enabled)
}
func TestEncodeNetworkMapEnvelope_RouterPeerNotInComponentsPeers(t *testing.T) {
// Router peer in c.RouterPeers but NOT in c.Peers (validation may have
// filtered it). indexRouterPeers runs before encodeRoutersMap, so the
// peer_index reference must still resolve.
c := newTestComponents()
delete(c.Peers, "peer-c")
routerPeer := &nbpeer.Peer{
ID: "peer-c", Key: testWgKeyC, IP: netip.AddrFrom4([4]byte{100, 64, 0, 3}),
DNSLabel: "peerc", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}
c.RouterPeers = map[string]*nbpeer.Peer{"peer-c": routerPeer}
c.NetworkXIDToSeq = map[string]uint32{"net-1": 5}
c.RoutersMap = map[string]map[string]*routerTypes.NetworkRouter{
"net-1": {"peer-c": {ID: "r-1", AccountSeqID: 1, Peer: "peer-c", Enabled: true}},
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Contains(t, full.RoutersMap, uint32(5))
require.Len(t, full.RoutersMap[5].Entries, 1)
e := full.RoutersMap[5].Entries[0]
assert.True(t, e.PeerIndexSet, "router peer must be indexed even when not in c.Peers")
}
func TestEncodeNetworkMapEnvelope_DNSSettingsFiltersUnpersistedGroups(t *testing.T) {
c := newTestComponents()
c.DNSSettings = &types.DNSSettings{
DisabledManagementGroups: []string{"group-src", "group-missing", "group-no-seq"},
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.DnsSettings)
assert.Equal(t, []uint32{1}, full.DnsSettings.DisabledManagementGroupIds,
"only group-src (AccountSeqID=1) survives — missing and unpersisted are dropped")
}
func TestEncodeNetworkMapEnvelope_GroupIDToUserIDs(t *testing.T) {
c := newTestComponents()
c.GroupIDToUserIDs = map[string][]string{
"group-src": {"user-1", "user-2"},
"group-no-seq": {"user-3"}, // group not persisted → drop
"group-missing": {"user-4"}, // group not in components → drop
}
c.Groups["group-no-seq"] = &types.Group{ID: "group-no-seq", AccountSeqID: 0}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.Len(t, full.GroupIdToUserIds, 1, "only persisted+present groups survive")
require.Contains(t, full.GroupIdToUserIds, uint32(1))
assert.ElementsMatch(t, []string{"user-1", "user-2"}, full.GroupIdToUserIds[1].UserIds)
}
func TestToProxyPatch_EmptyInputReturnsNil(t *testing.T) {
assert.Nil(t, toProxyPatch(nil, "netbird.cloud", false, false))
assert.Nil(t, toProxyPatch(&types.NetworkMap{}, "netbird.cloud", false, false),
"empty NetworkMap (no peers, rules, routes etc) → nil patch so proto3 omits the field")
}
func TestToProxyPatch_PopulatesAllFields(t *testing.T) {
nm := &types.NetworkMap{
Peers: []*nbpeer.Peer{{
ID: "ext-peer", Key: testWgKeyA, IP: netip.AddrFrom4([4]byte{100, 64, 0, 9}),
DNSLabel: "extpeer", Meta: nbpeer.PeerSystemMeta{WtVersion: "0.40.0"},
}},
FirewallRules: []*types.FirewallRule{{
PeerIP: "100.64.0.9", Action: "accept", Direction: 0, Protocol: "tcp",
}},
}
patch := toProxyPatch(nm, "netbird.cloud", false, false)
require.NotNil(t, patch)
assert.Len(t, patch.Peers, 1)
assert.Len(t, patch.FirewallRules, 1)
}
// TestEncodeNetworkMapEnvelope_ProxyPatchPropagated covers the ProxyPatch
// pass-through in both encoder branches (normal path + nil-Components
// graceful-degrade). Guards against a regression that drops `ProxyPatch:`
// from one of the envelope struct literals.
func TestEncodeNetworkMapEnvelope_ProxyPatchPropagated(t *testing.T) {
patch := &proto.ProxyPatch{
ForwardingRules: []*proto.ForwardingRule{{
Protocol: proto.RuleProtocol_TCP,
DestinationPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 80}},
TranslatedAddress: net.IPv4(10, 0, 0, 1).To4(),
TranslatedPort: &proto.PortInfo{PortSelection: &proto.PortInfo_Port{Port: 8080}},
}},
}
t.Run("normal_path", func(t *testing.T) {
c := newTestComponents()
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: c,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the normal encode path")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
t.Run("nil_components_graceful_degrade", func(t *testing.T) {
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
ProxyPatch: patch,
}).GetFull()
require.NotNil(t, full.ProxyPatch, "ProxyPatch must propagate through the nil-Components branch too")
assert.Len(t, full.ProxyPatch.ForwardingRules, 1)
})
}
func TestEncodeNetworkMapEnvelope_NilComponentsGracefulDegrade(t *testing.T) {
// nil Components → minimal envelope, no crash. Matches the legacy
// behaviour for missing/unvalidated peers.
env := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: nil,
DNSDomain: "netbird.cloud",
})
require.NotNil(t, env)
full := env.GetFull()
require.NotNil(t, full)
require.NotNil(t, full.AccountSettings, "AccountSettings must always be non-nil")
assert.Equal(t, "netbird.cloud", full.DnsDomain)
assert.Empty(t, full.Peers)
assert.Empty(t, full.Policies)
}
func TestEncodeNetworkMapEnvelope_AccountSettingsAlwaysEmitted(t *testing.T) {
c := &types.NetworkMapComponents{
Network: &types.Network{Identifier: "x", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(10, 32)}},
// AccountSettings deliberately nil
}
full := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{Components: c}).GetFull()
require.NotNil(t, full.AccountSettings, "client dereferences AccountSettings unconditionally during Calculate(); a nil here would crash the receiver")
assert.False(t, full.AccountSettings.PeerLoginExpirationEnabled)
assert.Zero(t, full.AccountSettings.PeerLoginExpirationNs)
}

View File

@@ -1,192 +0,0 @@
package grpc
import (
"context"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/networkmap"
"github.com/netbirdio/netbird/shared/management/proto"
)
// ToComponentSyncResponse builds a SyncResponse carrying the compact
// NetworkMapEnvelope for capability-aware peers. The legacy proto.NetworkMap
// field is intentionally left empty — capable peers ignore it and the
// envelope alone is the authoritative wire shape.
//
// PeerConfig is computed once server-side using the receiving peer's own
// account-level network metadata. EnableSSH inside PeerConfig is left at
// peer.SSHEnabled (the peer's local setting); account-policy-driven SSH is
// computed by the client from the envelope's GroupIDToUserIDs / AllowedUserIDs
// inside Calculate(), so the SshConfig.SshEnabled bit may flip true on the
// client even though the server-side PeerConfig reports false.
func ToComponentSyncResponse(
ctx context.Context,
config *nbconfig.Config,
httpConfig *nbconfig.HttpServerConfig,
deviceFlowConfig *nbconfig.DeviceAuthorizationFlow,
peer *nbpeer.Peer,
turnCredentials *Token,
relayCredentials *Token,
components *types.NetworkMapComponents,
proxyPatch *types.NetworkMap,
dnsName string,
checks []*posture.Checks,
settings *types.Settings,
extraSettings *types.ExtraSettings,
peerGroups []string,
dnsFwdPort int64,
) *proto.SyncResponse {
network := networkOrZero(components)
enableSSH := computeSSHEnabledForPeer(components, peer)
peerConfig := toPeerConfig(peer, network, dnsName, settings, httpConfig, deviceFlowConfig, enableSSH)
includeIPv6 := peer.SupportsIPv6() && peer.IPv6.IsValid()
useSourcePrefixes := peer.SupportsSourcePrefixes()
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
}
envelope := EncodeNetworkMapEnvelope(ComponentsEnvelopeInput{
Components: components,
PeerConfig: peerConfig,
DNSDomain: dnsName,
DNSForwarderPort: dnsFwdPort,
UserIDClaim: userIDClaim,
ProxyPatch: toProxyPatch(proxyPatch, dnsName, includeIPv6, useSourcePrefixes),
})
resp := &proto.SyncResponse{
PeerConfig: peerConfig,
NetworkMapEnvelope: envelope,
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
resp.NetbirdConfig = integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
return resp
}
// networkOrZero returns components.Network or a zero Network — toPeerConfig
// dereferences network.Net which would panic on nil.
func networkOrZero(c *types.NetworkMapComponents) *types.Network {
if c == nil || c.Network == nil {
return &types.Network{}
}
return c.Network
}
// toProxyPatch converts a proxy-injected *types.NetworkMap into the wire
// patch the components envelope ships alongside. Returns nil when there are
// no fragments to merge — proto3 omits a nil message field, so the receiver
// sees no patch and skips the merge step entirely.
//
// We reuse the legacy proto-conversion helpers (toProtocolRoutes,
// toProtocolFirewallRules, toProtocolRoutesFirewallRules,
// appendRemotePeerConfig, ForwardingRule.ToProto) because the proxy
// delivers fragments pre-expanded — there's no raw component shape to
// derive them from. Components purity isn't violated: proxy data isn't
// policy-graph-derived, it's externally injected post-Calculate, so the
// client merges it on top of its locally-computed NetworkMap.
func toProxyPatch(nm *types.NetworkMap, dnsName string, includeIPv6, useSourcePrefixes bool) *proto.ProxyPatch {
if nm == nil {
return nil
}
if len(nm.Peers) == 0 && len(nm.OfflinePeers) == 0 && len(nm.FirewallRules) == 0 &&
len(nm.Routes) == 0 && len(nm.RoutesFirewallRules) == 0 && len(nm.ForwardingRules) == 0 {
return nil
}
patch := &proto.ProxyPatch{
Peers: networkmap.AppendRemotePeerConfig(nil, nm.Peers, dnsName, includeIPv6),
OfflinePeers: networkmap.AppendRemotePeerConfig(nil, nm.OfflinePeers, dnsName, includeIPv6),
FirewallRules: networkmap.ToProtocolFirewallRules(nm.FirewallRules, includeIPv6, useSourcePrefixes),
Routes: networkmap.ToProtocolRoutes(nm.Routes),
RouteFirewallRules: networkmap.ToProtocolRoutesFirewallRules(nm.RoutesFirewallRules),
}
if len(nm.ForwardingRules) > 0 {
patch.ForwardingRules = make([]*proto.ForwardingRule, 0, len(nm.ForwardingRules))
for _, r := range nm.ForwardingRules {
patch.ForwardingRules = append(patch.ForwardingRules, r.ToProto())
}
}
return patch
}
// computeSSHEnabledForPeer mirrors the SSH-server-activation bit that
// Calculate() folds into NetworkMap.EnableSSH. Components-format peers
// receive a freshly-computed PeerConfig.SshConfig.SshEnabled at sync time;
// without this helper the field would be incorrectly false for any peer
// that's the destination of an SSH-enabling policy without having
// peer.SSHEnabled set locally.
//
// Mirrors the two activation paths Calculate() uses:
// 1. Explicit: rule.Protocol == NetbirdSSH and peer is in the rule's
// destinations.
// 2. Legacy implicit: rule covers TCP/22 or TCP/22022 (or ALL), peer is in
// destinations, AND the peer has SSHEnabled set locally — this is the
// "allow-all/TCP-22 implies SSH activation for SSH-capable peers" path.
//
// The full SSH AuthorizedUsers map is still produced by the client when it
// runs Calculate() over the envelope.
func computeSSHEnabledForPeer(c *types.NetworkMapComponents, peer *nbpeer.Peer) bool {
if c == nil || peer == nil {
return false
}
// Mirror Calculate's `getAllPeersFromGroups` invariant: target peer must
// exist in c.Peers, otherwise no rule applies to it.
if _, ok := c.Peers[peer.ID]; !ok {
return false
}
for _, policy := range c.Policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if ruleEnablesSSHForPeer(c, rule, peer) {
return true
}
}
}
return false
}
// ruleEnablesSSHForPeer returns true when rule is active, targets peer, and
// either explicitly authorises SSH or covers the legacy TCP/22 path while the
// peer itself has SSH enabled locally.
func ruleEnablesSSHForPeer(c *types.NetworkMapComponents, rule *types.PolicyRule, peer *nbpeer.Peer) bool {
if rule == nil || !rule.Enabled {
return false
}
if !peerInDestinations(c, rule, peer.ID) {
return false
}
if rule.Protocol == types.PolicyRuleProtocolNetbirdSSH {
return true
}
return peer.SSHEnabled && types.PolicyRuleImpliesLegacySSH(rule)
}
// peerInDestinations reports whether peerID is in any of rule.Destinations'
// groups (or matches DestinationResource if it's a peer-typed resource —
// for non-peer types Calculate falls through to group lookup, so we mirror
// that exactly to avoid silent divergence).
func peerInDestinations(c *types.NetworkMapComponents, rule *types.PolicyRule, peerID string) bool {
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
return rule.DestinationResource.ID == peerID
}
for _, groupID := range rule.Destinations {
if c.IsPeerInGroup(peerID, groupID) {
return true
}
}
return false
}

View File

@@ -1,184 +0,0 @@
package grpc
import (
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types"
)
// TestComputeSSHEnabledForPeer covers both Calculate-mirroring branches:
// explicit NetbirdSSH protocol, and the legacy implicit case where a
// TCP/22 (or 22022 / ALL / port-range-covering-22) rule activates SSH when
// the destination peer has SSHEnabled=true locally.
func TestComputeSSHEnabledForPeer(t *testing.T) {
const targetPeerID = "target"
const targetGroupID = "g_dst"
mkComponents := func(rule *types.PolicyRule, sshEnabled bool) (*types.NetworkMapComponents, *nbpeer.Peer) {
peer := &nbpeer.Peer{ID: targetPeerID, SSHEnabled: sshEnabled}
group := &types.Group{ID: targetGroupID, Name: "dst", Peers: []string{targetPeerID}}
return &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{targetPeerID: peer},
Groups: map[string]*types.Group{targetGroupID: group},
Policies: []*types.Policy{{
ID: "p",
Enabled: true,
Rules: []*types.PolicyRule{rule},
}},
}, peer
}
cases := []struct {
name string
peerSSH bool
rule types.PolicyRule
wantEnabled bool
}{
{
name: "explicit-netbird-ssh-activates-regardless-of-peer-ssh",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-tcp-22-without-peer-ssh-disabled",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "implicit-tcp-22022-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"22022"},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-all-protocol-with-peer-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolALL,
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "implicit-port-range-covers-22",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
PortRanges: []types.RulePortRange{{Start: 20, End: 30}},
Destinations: []string{targetGroupID},
},
wantEnabled: true,
},
{
name: "tcp-80-no-ssh",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"},
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "disabled-rule-skipped",
peerSSH: true,
rule: types.PolicyRule{
Enabled: false, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{targetGroupID},
},
wantEnabled: false,
},
{
name: "peer-not-in-destinations",
peerSSH: true,
rule: types.PolicyRule{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g_other"}, // target not in this group
},
wantEnabled: false,
},
{
name: "peer-typed-destination-resource-matches",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: types.ResourceTypePeer},
},
wantEnabled: true,
},
{
name: "non-peer-destination-resource-falls-through-to-groups",
peerSSH: false,
rule: types.PolicyRule{
Enabled: true,
Protocol: types.PolicyRuleProtocolNetbirdSSH,
DestinationResource: types.Resource{ID: targetPeerID, Type: "host"}, // wrong type
Destinations: []string{targetGroupID}, // saved by group fallback
},
wantEnabled: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c, peer := mkComponents(&tc.rule, tc.peerSSH)
got := computeSSHEnabledForPeer(c, peer)
assert.Equal(t, tc.wantEnabled, got)
})
}
}
// TestComputeSSHEnabledForPeer_TargetMissingFromComponents covers the
// belt-and-suspenders presence guard mirroring Calculate's
// getAllPeersFromGroups invariant.
func TestComputeSSHEnabledForPeer_TargetMissingFromComponents(t *testing.T) {
peer := &nbpeer.Peer{ID: "missing", SSHEnabled: true}
c := &types.NetworkMapComponents{
Peers: map[string]*nbpeer.Peer{}, // target peer NOT present
Groups: map[string]*types.Group{
"g": {ID: "g", Peers: []string{"missing"}},
},
Policies: []*types.Policy{{
ID: "p", Enabled: true,
Rules: []*types.PolicyRule{{
Enabled: true, Protocol: types.PolicyRuleProtocolNetbirdSSH,
Destinations: []string{"g"},
}},
}},
}
assert.False(t, computeSSHEnabledForPeer(c, peer),
"missing target peer must short-circuit to false, not consult policies")
}
// TestComputeSSHEnabledForPeer_NilInputs guards the cheap nil-checks at
// function entry — Calculate doesn't accept nil either, but the helper is
// exported indirectly via ToComponentSyncResponse and may receive nil
// components on graceful-degrade paths.
func TestComputeSSHEnabledForPeer_NilInputs(t *testing.T) {
assert.False(t, computeSSHEnabledForPeer(nil, &nbpeer.Peer{ID: "x"}))
assert.False(t, computeSSHEnabledForPeer(&types.NetworkMapComponents{}, nil))
}

View File

@@ -10,20 +10,24 @@ import (
"github.com/hashicorp/go-version"
nbversion "github.com/netbirdio/netbird/version"
log "github.com/sirupsen/logrus"
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/networkmap"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
"github.com/netbirdio/netbird/shared/sshauth"
)
const (
@@ -155,8 +159,8 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: networkmap.ToProtocolRoutes(networkMap.Routes),
DNSConfig: networkmap.ToProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig, networkMap.EnableSSH),
},
Checks: toProtocolChecks(ctx, checks),
@@ -169,7 +173,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.NetworkMap.PeerConfig = response.PeerConfig
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
remotePeers = networkmap.AppendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, includeIPv6)
if !shouldSkipSendingDeprecatedRemotePeers(peer.Meta.WtVersion) {
response.RemotePeers = remotePeers
@@ -179,13 +183,13 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
response.RemotePeersIsEmpty = len(remotePeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = networkmap.AppendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, includeIPv6)
firewallRules := networkmap.ToProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules, includeIPv6, useSourcePrefixes)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
routesFirewallRules := networkmap.ToProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
routesFirewallRules := toProtocolRoutesFirewallRules(networkMap.RoutesFirewallRules)
response.NetworkMap.RoutesFirewallRules = routesFirewallRules
response.NetworkMap.RoutesFirewallRulesIsEmpty = len(routesFirewallRules) == 0
@@ -198,7 +202,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
}
if networkMap.AuthorizedUsers != nil {
hashedUsers, machineUsers := networkmap.BuildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
hashedUsers, machineUsers := buildAuthorizedUsersProto(ctx, networkMap.AuthorizedUsers)
userIDClaim := auth.DefaultUserIDClaim
if httpConfig != nil && httpConfig.AuthUserIDClaim != "" {
userIDClaim = httpConfig.AuthUserIDClaim
@@ -238,6 +242,33 @@ func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
return timestamppb.New(deadline)
}
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
userIDToIndex := make(map[string]uint32)
var hashedUsers [][]byte
machineUsers := make(map[string]*proto.MachineUserIndexes, len(authorizedUsers))
for machineUser, users := range authorizedUsers {
indexes := make([]uint32, 0, len(users))
for userID := range users {
idx, exists := userIDToIndex[userID]
if !exists {
hash, err := sshauth.HashUserID(userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to hash user id %s: %v", userID, err)
continue
}
idx = uint32(len(hashedUsers))
userIDToIndex[userID] = idx
hashedUsers = append(hashedUsers, hash[:])
}
indexes = append(indexes, idx)
}
machineUsers[machineUser] = &proto.MachineUserIndexes{Indexes: indexes}
}
return hashedUsers, machineUsers
}
func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
if nbversion.IsDevelopmentVersion(peerVersion) {
return true
@@ -251,6 +282,51 @@ func shouldSkipSendingDeprecatedRemotePeers(peerVersion string) bool {
return precomputedDeprecatedRemotePeersConstraint.Check(peerNBVersion)
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, includeIPv6 bool) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
allowedIPs := []string{rPeer.IP.String() + "/32"}
if includeIPv6 && rPeer.IPv6.IsValid() {
allowedIPs = append(allowedIPs, rPeer.IPv6.String()+"/128")
}
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: allowedIPs,
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
AgentVersion: rPeer.Meta.WtVersion,
})
}
return dst
}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *cache.DNSConfigCache, forwardPort int64) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
ForwarderPort: forwardPort,
}
for _, zone := range update.CustomZones {
protoZone := convertToProtoCustomZone(zone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
}
return protoUpdate
}
func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
switch configProto {
case nbconfig.UDP:
@@ -268,6 +344,203 @@ func ToResponseProto(configProto nbconfig.Protocol) proto.HostConfig_Protocol {
}
}
func toProtocolRoutes(routes []*nbroute.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}
return protoRoutes
}
func toProtocolRoute(route *nbroute.Route) *proto.Route {
return &proto.Route{
ID: string(route.ID),
NetID: string(route.NetID),
Network: route.Network.String(),
Domains: route.Domains.ToPunycodeList(),
NetworkType: int64(route.NetworkType),
Peer: route.Peer,
Metric: int64(route.Metric),
Masquerade: route.Masquerade,
KeepRoute: route.KeepRoute,
SkipAutoApply: route.SkipAutoApply,
}
}
// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
// When useSourcePrefixes is true, the compact SourcePrefixes field is populated
// alongside the deprecated PeerIP for forward compatibility.
// Wildcard rules ("0.0.0.0") are expanded into separate v4 and v6 SourcePrefixes
// when includeIPv6 is true.
func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSourcePrefixes bool) []*proto.FirewallRule {
result := make([]*proto.FirewallRule, 0, len(rules))
for i := range rules {
rule := rules[i]
fwRule := &proto.FirewallRule{
PolicyID: []byte(rule.PolicyID),
PeerIP: rule.PeerIP, //nolint:staticcheck // populated for backward compatibility
Direction: getProtoDirection(rule.Direction),
Action: getProtoAction(rule.Action),
Protocol: getProtoProtocol(rule.Protocol),
Port: rule.Port,
}
if useSourcePrefixes && rule.PeerIP != "" {
result = append(result, populateSourcePrefixes(fwRule, rule, includeIPv6)...)
}
if shouldUsePortRange(fwRule) {
fwRule.PortInfo = rule.PortRange.ToProto()
}
result = append(result, fwRule)
}
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {
addr, err := netip.ParseAddr(rule.PeerIP)
if err != nil {
return nil
}
if !addr.IsUnspecified() {
fwRule.SourcePrefixes = [][]byte{netiputil.EncodeAddr(addr.Unmap())}
return nil
}
// IPv4Unspecified/0 is always valid, error is impossible.
v4Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv4Unspecified(), 0))
fwRule.SourcePrefixes = [][]byte{v4Wildcard}
if !includeIPv6 {
return nil
}
v6Rule := goproto.Clone(fwRule).(*proto.FirewallRule)
v6Rule.PeerIP = "::" //nolint:staticcheck // populated for backward compatibility
// IPv6Unspecified/0 is always valid, error is impossible.
v6Wildcard, _ := netiputil.EncodePrefix(netip.PrefixFrom(netip.IPv6Unspecified(), 0))
v6Rule.SourcePrefixes = [][]byte{v6Wildcard}
if shouldUsePortRange(v6Rule) {
v6Rule.PortInfo = rule.PortRange.ToProto()
}
return []*proto.FirewallRule{v6Rule}
}
// getProtoDirection converts the direction to proto.RuleDirection.
func getProtoDirection(direction int) proto.RuleDirection {
if direction == types.FirewallRuleDirectionOUT {
return proto.RuleDirection_OUT
}
return proto.RuleDirection_IN
}
func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule {
result := make([]*proto.RouteFirewallRule, len(rules))
for i := range rules {
rule := rules[i]
result[i] = &proto.RouteFirewallRule{
SourceRanges: rule.SourceRanges,
Action: getProtoAction(rule.Action),
Destination: rule.Destination,
Protocol: getProtoProtocol(rule.Protocol),
PortInfo: getProtoPortInfo(rule),
IsDynamic: rule.IsDynamic,
Domains: rule.Domains.ToPunycodeList(),
PolicyID: []byte(rule.PolicyID),
RouteID: string(rule.RouteID),
}
}
return result
}
// getProtoAction converts the action to proto.RuleAction.
func getProtoAction(action string) proto.RuleAction {
if action == string(types.PolicyTrafficActionDrop) {
return proto.RuleAction_DROP
}
return proto.RuleAction_ACCEPT
}
// getProtoProtocol converts the protocol to proto.RuleProtocol.
func getProtoProtocol(protocol string) proto.RuleProtocol {
switch types.PolicyRuleProtocolType(protocol) {
case types.PolicyRuleProtocolALL:
return proto.RuleProtocol_ALL
case types.PolicyRuleProtocolTCP:
return proto.RuleProtocol_TCP
case types.PolicyRuleProtocolUDP:
return proto.RuleProtocol_UDP
case types.PolicyRuleProtocolICMP:
return proto.RuleProtocol_ICMP
default:
return proto.RuleProtocol_UNKNOWN
}
}
// getProtoPortInfo converts the port info to proto.PortInfo.
func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
var portInfo proto.PortInfo
if rule.Port != 0 {
portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)}
} else if portRange := rule.PortRange; portRange.Start != 0 && portRange.End != 0 {
portInfo.PortSelection = &proto.PortInfo_Range_{
Range: &proto.PortInfo_Range{
Start: uint32(portRange.Start),
End: uint32(portRange.End),
},
}
}
return &portInfo
}
func shouldUsePortRange(rule *proto.FirewallRule) bool {
return rule.Port == "" && (rule.Protocol == proto.RuleProtocol_UDP || rule.Protocol == proto.RuleProtocol_TCP)
}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
SearchDomainDisabled: zone.SearchDomainDisabled,
NonAuthoritative: zone.NonAuthoritative,
}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
func buildJWTConfig(config *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.JWTConfig {
if config == nil || config.AuthAudience == "" {

View File

@@ -13,7 +13,6 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/shared/management/networkmap"
)
func TestToProtocolDNSConfigWithCache(t *testing.T) {
@@ -63,13 +62,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) {
}
// First run with config1
result1 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
result1 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Second run with config2
result2 := networkmap.ToProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
result2 := toProtocolDNSConfig(config2, &cache, int64(network_map.DnsForwarderPort))
// Third run with config1 again
result3 := networkmap.ToProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
result3 := toProtocolDNSConfig(config1, &cache, int64(network_map.DnsForwarderPort))
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
@@ -101,7 +100,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
@@ -109,7 +108,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &cache.DNSConfigCache{}
networkmap.ToProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
toProtocolDNSConfig(testData, cache, int64(network_map.DnsForwarderPort))
}
})
}

View File

@@ -11,9 +11,9 @@ import (
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
metaChangeLimit = 5 // Number of reconnections with different metadata that triggers a ban of one peer
)
type lfConfig struct {
@@ -142,6 +142,7 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))

View File

@@ -1016,31 +1016,7 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
return status.Errorf(codes.Internal, "failed to get peer groups %s", err)
}
dnsName := s.networkMapController.GetDNSDomain(settings)
var plainResp *proto.SyncResponse
if s.networkMapController.PeerNeedsComponents(peer) {
// Capable peer: discard the legacy NetworkMap that SyncAndMarkPeer
// computed and recompute the raw components instead. This wastes one
// Calculate() call per initial-sync — the component-based wire
// format is what the peer actually consumes. The streaming path
// (network_map.Controller.UpdateAccountPeers) skips this duplication
// because it dispatches by capability before computing.
//
// TODO: refactor SyncPeer / SyncAndMarkPeer / their mocks + manager
// interfaces to return PeerNetworkMapResult so the initial-sync path
// stops doing duplicate work. Deferred until the client-side
// decoder lands and there's a real deployment of capability=3 peers
// worth optimizing for.
_, components, proxyPatch, _, _, err := s.networkMapController.GetValidatedPeerWithComponents(ctx, false, peer.AccountID, peer)
if err != nil {
log.WithContext(ctx).Errorf("failed to build components for peer %s on initial sync: %v", peer.ID, err)
return status.Errorf(codes.Internal, "failed to build initial sync envelope")
}
plainResp = ToComponentSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, components, proxyPatch, dnsName, postureChecks, settings, settings.Extra, peerGroups, dnsFwdPort)
} else {
plainResp = ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, dnsName, postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
}
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
key, err := s.secretsManager.GetWGKey()
if err != nil {

View File

@@ -1636,14 +1636,6 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil
}
for _, g := range newGroupsToCreate {
seq, err := transaction.AllocateAccountSeqID(ctx, userAuth.AccountId, types.AccountSeqEntityGroup)
if err != nil {
return fmt.Errorf("error allocating group seq id: %w", err)
}
g.AccountSeqID = seq
}
if err = transaction.CreateGroups(ctx, userAuth.AccountId, newGroupsToCreate); err != nil {
return fmt.Errorf("error saving groups: %w", err)
}

View File

@@ -3170,16 +3170,6 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added")
var newJWTGroup *types.Group
for _, g := range groups {
if g.Name == "group3" {
newJWTGroup = g
break
}
}
require.NotNil(t, newJWTGroup, "JIT-created JWT group not found")
assert.NotZero(t, newJWTGroup.AccountSeqID, "JIT-created JWT group must have a non-zero AccountSeqID")
})
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {

Some files were not shown because too many files have changed in this diff Show More