Compare commits

..

19 Commits

Author SHA1 Message Date
pascal
d6759ce93c run on meta update 2026-06-18 00:12:36 +02:00
Maycon Santos
8d9580e491 [misc] improve goreleaser with RC handling and update docker builds (#6438)
- introduce variables to avoid publishing latest docker tags and installers
- Refactor .goreleaser.yaml to simplify docker configurations and add environment-driven flags
- removed management debug containers (it was doing only log var)
- Stopped building arm v6 32bits in favor of v7 32 bits for services (not client)
- Add target argument to docker files
2026-06-17 20:13:13 +02:00
Viktor Liu
5bd7c6c7ea [client] Detect and recover from a stalled signal receive stream (#6459) 2026-06-17 18:48:09 +02:00
Zoltan Papp
8ae2cd0a08 [client] Fix ios route notify ordering (#6454)
* [client] fix iOS route-update reordering that black-holed IPv6 on exit-node disable

On iOS the route notifier delivered each prefix update from its own
fire-and-forget goroutine (notify -> `go func`), so Go provided no ordering
guarantee between consecutive updates. It also read currentPrefixes inside
that goroutine without holding the lock, racing the next OnNewPrefixes write.

On exit-node disable the core removes the default routes as two separate
prefix updates (0.0.0.0/0, then the synthesized ::/0). When the two
goroutines were reordered, the stale snapshot still containing ::/0 was
delivered last and clobbered the correct default-free one. iOS then kept the
::/0 default route on the tunnel with no exit node to carry it, black-holing
all IPv6 traffic while IPv4 recovered correctly.

Fix: deliver updates through a single worker goroutine fed by a buffered
channel, preserving production order, and snapshot the joined prefix string
under the mutex so it can't race a concurrent update. Buffered so producers
(which run under the route manager lock) don't block on the listener callback.

* [client] close iOS notifier delivery goroutine on Stop, unbounded queue

The delivery goroutine was never stopped, leaking on every engine
restart. Add Notifier.Close, called from the route manager Stop after
routing cleanup.

Replace the buffered update channel with a cond-driven linked-list
queue so route-update producers (running under the route manager lock)
never block when the listener callback is slow.
2026-06-17 18:29:33 +02:00
Pascal Fischer
e4397d4d46 [management] remove nmap calc from login (#6449) 2026-06-17 16:37:24 +02:00
Viktor Liu
6fbc90b4d3 [client, relay] Expose relay transport and connection errors in status and metrics (#6342) 2026-06-17 15:41:48 +02:00
Riccardo Manfrin
5095e17cc5 [management] fix flaky Test_SaveAccount_Large from random IP collision (#6452) 2026-06-17 14:00:50 +02:00
Zoltan Papp
6df0175607 [client] Add IsLoginRequiredCached for iOS mobile client (#6447)
Expose a network-free login-required check backed by the in-memory status
recorder. Unlike IsLoginRequired(), which creates a fresh auth client and
performs a blocking network call, IsLoginRequiredCached() reports whether the
LAST observed management error was an auth failure (PermissionDenied/
InvalidArgument).

This lets the iOS connection listener detect a mid-session token expiry from
within onDisconnected during teardown without blocking on a slow or
unavailable network.
2026-06-16 16:15:19 +02:00
Zoltan Papp
3c23700e56 [client] Add iOS debug bundle support in Go (#6270)
* Add iOS debug bundle support in Go

Thread cacheDir through NewClient -> RunOniOS -> MobileDependency.TempDir
so the iOS client can pass its sandbox-writable cache directory for
debug bundle zip file creation instead of os.TempDir().

Move log collection into platform-dispatched addPlatformLog():
- iOS: adds the file-based Go client log (with rotation, stderr/stdout
  companions and anonymization handled by addLogfile) plus the Swift app
  log (swift-log.log) written by the iOS app into the same log directory
- Other non-Android platforms: existing file-based log + systemd fallback

Narrow the debug_nonandroid.go build tag to !android && !ios so iOS no
longer attempts the systemd journal fallback.

Add a DebugBundle() entry point to the iOS Go client that generates a
bundle, uploads it and returns the upload key. It works with or without
a running engine: when the engine is up it reuses the live config, sync
response and client metrics; otherwise it loads the config from disk (or
the preloaded tvOS config). Guard the live config/ConnectClient behind a
state mutex since DebugBundle may run on a different thread.

* Include the iOS state file in the debug bundle

addStateFile() resolved the state path via ServiceManager.GetStatePath(),
which on iOS points at a hard-coded default that does not exist in the app
sandbox, so the state file was silently skipped.

Add an optional StatePath to GeneratorDependencies and use it when set,
falling back to the ServiceManager default otherwise. The iOS DebugBundle
passes the client's actual state file path (the App Group profile state),
matching the Android bundle which includes the state file.

* ios: enable sync response persistence for debug bundle

Turn on sync response persistence before starting the engine so
DebugBundle can include the network map. On iOS the store is disk-backed
(see syncstore) to keep the map out of the constrained process memory.

* ios: pass log file path through NewClient constructor (#6393)

Add logFilePath field to Client struct and expose it as a parameter
in NewClient so callers provide the Go log path at construction time.
Wire it into DebugBundle via GeneratorDependencies.LogPath so the
debug bundle includes client.log and swift-log.log regardless of
whether the bundle is triggered by the app or the management server.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* ios: pass log file path to engine for remote debug bundles

RunOniOS started the engine with an empty LogPath, so EngineConfig.LogPath
was never set. Management-triggered (jobs) debug bundles read the log path
from the engine config, so they collected no client logs (client.log,
rotated logs, swift-log.log). The GUI path was unaffected because it passes
c.logFilePath directly to the bundle generator.

Thread c.logFilePath through RunOniOS into the engine config so remote
bundles include the client logs too.

---------

Co-authored-by: evgeniyChepelev <68751844+evgeniyChepelev@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-16 15:54:46 +02:00
Pascal Fischer
38ad2b67e8 [proxy] fix context for udprelay (#6444) 2026-06-16 14:41:17 +02:00
Pascal Fischer
01aa49433e [management] delete targets when deleting exposed service (#6442) 2026-06-16 14:33:24 +02:00
Zoltan Papp
08a2b63675 [client] propagate exit-node deselect to synthesized v6 (::/0) route (#6296)
* [client] propagate exit-node deselect to synthesized v6 (::/0) route

When a client deselects an IPv4 exit node, the auto-generated IPv6 default
route (::/0) was still selected and pushed onto the tunnel interface, even
though the user disabled the exit node. On an exit node without a real IPv6
egress this blackholes IPv6 traffic, and because clients prefer IPv6 (happy
eyeballs) it can break general connectivity.

Root cause: the synthesized v6 route gets a different NetID than its v4 base
(base + "-v6"). The route selector keys deselects by NetID and defaults
unknown NetIDs to selected, so the "-v6" entry was never matched by the v4
deselect. The effectiveNetID() mirror that solves exactly this is used by
HasUserSelectionForRoute and FilterSelectedExitNodes, but categorizeUserSelection
called the raw IsSelected(), bypassing it and mis-categorizing the v6 pair as
user-selected.

Add RouteSelector.IsSelectedForExitNode(), which applies effectiveNetID before
the selection check, and use it in categorizeUserSelection. IsSelected() is left
untouched so non-exit code paths don't make unrelated "*-v6" routes inherit v4
state. Adds regression tests for the v4/v6 deselect mirror and explicit-v6
override.

* [client] add DIAG logging to trace exit-node v6 (::/0) route filtering

Temporary diagnostics to find why a deselected v4 exit node's synthesized
::/0 route still reaches the tunnel. Logs the full install path: incoming
client networks, route-selector state before/after the management-driven
update, what updateExitNodeSelections deselects/selects, and per-route
KEEP/SKIP/DROP decisions in FilterSelectedExitNodes and applyExitNodeFilter.
To be reverted once the real root cause is confirmed from a client log.

* [client] clear orphaned v6 exit selection when v4 pair is toggled

Root cause of the leaking ::/0 route, confirmed from client logs: the
synthesized "-v6" exit route could stay explicitly selected in the persisted
route-selector state while its v4 base was deselected (selected=[...-v6],
deselected=[...v4base]). Because the v6 entry then has its own explicit state,
effectiveNetID stops mirroring the v4 base, so FilterSelectedExitNodes keeps
::/0 and it is installed on the tunnel even though the user disabled the exit
node. This happened because the iOS SDK's deselect only pairs the "-v6" sibling
via ExpandV6ExitPairs when the v6 route is present in the current routesMap; a
deselect at a moment it wasn't expanded left the v6 selection orphaned.

Fix at the selector write path so it is independent of routesMap timing: when a
v4 exit NetID is selected or deselected, clear any orphaned explicit state on
its "-v6" sibling (clearPairedV6Locked), unless the sibling is part of the same
batch (the deliberate ExpandV6ExitPairs case). The v6 then falls back to
inheriting the v4 base via effectiveNetID, so a v4 deselect also drops ::/0 and
a v4 select brings both back.

Adds regression tests: a stale explicit v6 selection is cleared by a later v4
deselect, and an explicit v6 select made in the same batch is preserved.

* [ios] compute route connection status in the bridge

The iOS bridge exposed a route's Network as a possibly comma-joined string
("0.0.0.0/0, ::/0" for a merged exit node) but no connection status, forcing
the UI to infer status by string-matching that joined value against peer
routes — which never matched for the merged exit node, leaving it stuck as
not-connected. Android already computes status in the core (findBestRoutePeer).

Mirror that here: add a Status field to RoutesSelectionInfo and compute it from
the connected peers' route tables, matching the route's primary prefix, a merged
exit node's extra v6 prefix, or a dynamic route's domain pattern (the key the
route manager records). The UI can now read the status directly.

* [client] remove exit-node v6 DIAG logging and tidy routeselector

Drop the temporary DIAG diagnostics added to trace the leaking ::/0 route
(the root cause is fixed and confirmed). Also reorganize routeselector.go so
the exit-node helpers (clearPairedV6Locked, isExitNode) sit next to the
exit-node code paths and MarshalJSON/UnmarshalJSON are grouped together.

* [client] mirror v4 exit selection onto v6 pair at write time

The synthesized "-v6" exit route shares its v4 base's NetID plus a "-v6"
suffix. Selection state was reconciled at read time via effectiveNetID, a
mirror that could only be applied on exit-node code paths, which forced a
parallel IsSelectedForExitNode() alongside IsSelected() and a clearPairedV6Locked()
orphan cleanup on every toggle. That machinery still missed the case observed
in the field: a persisted state with the v4 base deselected but its "-v6"
sibling explicitly selected (orphaned). Because effectiveNetID returns the v6
entry itself once it carries explicit state, and clearPairedV6Locked only fires
on a live toggle, the loaded orphan survived and the ::/0 route leaked onto the
tunnel despite the exit node being disabled, breaking IPv6 (happy eyeballs).

Treat the v4/v6 exit pair as a single toggle and keep state consistent at write
time instead. RouteSelector.SyncPairedSelection forces the "-v6" entry to match
its v4 base unconditionally, resetting any orphaned explicit state. The route
manager, which knows the route prefixes, computes the pairs (V6ExitMergeSet) and
calls it from updateRouteSelectorFromManagement before selection is read, so both
collectExitNodeInfo and FilterSelectedExitNodes see consistent state, including
pairs loaded from persisted selector state.

This removes effectiveNetID, IsSelectedForExitNode and clearPairedV6Locked; the
selector is literal again and no longer needs the "exit-node paths only" caveat.
HasUserSelectionForRoute and applyExitNodeFilter use the raw NetID.

Adds a selector test for SyncPairedSelection (including the orphaned-v6 case) and
a route-manager test reproducing the persisted-orphan scenario from the field log.

* [client] add DIAG logging to trace v6 exit-pair mirror

The write-time mirror did not eliminate the leak in field testing. Re-add the
DIAG diagnostics around the exit-node selection flow to capture a fresh trace:

- UpdateRoutes: incoming client networks, selector state before/after the
  management update, and the networks remaining after FilterSelectedExitNodes.
- mirrorV6ExitPairSelections: the NetIDs present in this update and the v6 pairs
  V6ExitMergeSet derives from them (reveals whether the v4 base and its ::/0 pair
  are present in the same update so the pair can be matched).
- SyncPairedSelection: the base/paired state before and after the sync.
- FilterSelectedExitNodes / applyExitNodeFilter: per-route SKIP/KEEP/DROP and the
  selection lookups behind each decision.
- updateExitNodeSelections / logExitNodeUpdate: categorization and deselect set.

Temporary; to be removed once the root cause is confirmed.

* [client] remove v6 exit-pair mirror DIAG logging

Drop the temporary DIAG diagnostics added to trace the v4/v6 exit-pair mirror.
The field log confirmed the write-time mirror keeps the pair consistent (the
::/0 route is only ever applied alongside its v4 base and is dropped on deselect),
so the diagnostics are no longer needed.
2026-06-16 12:27:58 +02:00
Maycon Santos
b3f9e6588a [management] sync openapi spec and test for diff on workflows (#6437)
* [management] sync openapi spec and test for diff on workflows

* [management] pin oapi-codegen version to v2.7.1
2026-06-15 17:53:25 +02:00
Pascal Fischer
967e2d6864 [management] network map for affected peers (#6105) 2026-06-15 17:43:22 +02:00
Zoltan Papp
e7c1d364c3 [management] treat ci- builds as development for remote jobs (#6436)
* fix(management): treat ci- builds as development for remote jobs

CI snapshot builds use a "ci-<sha>" version string that did not match
IsDevelopmentVersion, so the remote-jobs minimum-version gate rejected
them. Recognize the "ci-" prefix as a development build.

* fix(management): treat dev- builds as development for remote jobs

Dev snapshot builds use a "dev-<sha>" version string that did not match
IsDevelopmentVersion, so the remote-jobs minimum-version gate rejected
them. Recognize the "dev-" prefix as a development build, alongside the
existing "ci-" prefix.
2026-06-15 17:22:40 +02:00
Viktor Liu
a44198fd77 [client] Add dialWebSocket method to WASM client (#5980) 2026-06-15 16:43:24 +02:00
Viktor Liu
b57f714350 [client] Drop signaling-side ICE candidate filter, drop overlay STUN at mux read-side instead (#6142) 2026-06-15 16:37:03 +02:00
Viktor Liu
f893abc41d [client] Recover from tun device read/write panics and restart the client (#6419) 2026-06-15 16:36:00 +02:00
Lee Sang Hoon
60067619a1 [proxy] Keep custom TCP listeners alive after mapping batches (#6415) 2026-06-15 12:21:24 +02:00
147 changed files with 8491 additions and 2354 deletions

View File

@@ -9,10 +9,13 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.5"
GORELEASER_VER: "v2.14.3"
SIGN_PIPE_VER: "v0.1.6"
GORELEASER_VER: "v2.16.0"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
flags: ""
SKIP_PUBLISH: "true"
SKIP_DOCKER_PUSH: "false"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -130,8 +133,6 @@ jobs:
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -143,8 +144,27 @@ jobs:
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set snapshot flag
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
if [[ "x-${{ github.repository }}" != "x-netbirdio/netbird" ]]; then
echo "SKIP_DOCKER_PUSH=true" >> $GITHUB_ENV
fi
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
@@ -161,6 +181,8 @@ jobs:
${{ runner.os }}-go-releaser-
- name: Install modules
run: go mod tidy
- name: run openapi generator
run: bash shared/management/http/api/generate.sh
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
@@ -210,6 +232,8 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
SKIP_DOCKER_PUSH: ${{ env.SKIP_DOCKER_PUSH }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
@@ -332,8 +356,22 @@ jobs:
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set snapshot flag
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
@@ -393,6 +431,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '

View File

@@ -1,5 +1,7 @@
version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
- SKIP_DOCKER_PUSH={{ if index .Env "SKIP_DOCKER_PUSH" }}{{ .Env.SKIP_DOCKER_PUSH }}{{ else }}false{{ end }}
project_name: netbird
builds:
- id: netbird-wasm
@@ -74,6 +76,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -88,6 +92,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -102,6 +108,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -122,6 +130,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -136,6 +146,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -150,6 +162,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -170,6 +184,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -222,670 +238,192 @@ nfpms:
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
ids:
- netbird-signal
goarch: amd64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids:
- netbird-signal
goarch: arm64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids:
- netbird-signal
goarch: arm
goarm: 6
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:latest
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/signal:latest
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/management:{{ .Version }}
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:latest
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:debug-latest
image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
dockers_v2:
- id: netbird
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird
images:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
platforms:
- linux/amd64
- linux/arm64
- linux/arm/6
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-rootless
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird
images:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}-rootless"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
platforms:
- linux/amd64
- linux/arm64
- linux/arm/6
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: relay
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-relay
images:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: signal
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-signal
images:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: management
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-mgmt
images:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: upload
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-upload
images:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-server
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-server
images:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-proxy
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-proxy
images:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
brews:
- ids:
- default
skip_upload: "{{ .Env.SKIP_PUBLISH }}"
repository:
owner: netbirdio
name: homebrew-tap
@@ -902,6 +440,7 @@ brews:
uploads:
- name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_deb
mode: archive
@@ -910,6 +449,7 @@ uploads:
method: PUT
- name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_rpm
mode: archive

View File

@@ -1,5 +1,6 @@
version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -101,6 +102,7 @@ nfpms:
uploads:
- name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_ui_deb
mode: archive
@@ -109,6 +111,7 @@ uploads:
method: PUT
- name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_ui_rpm
mode: archive

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.3
FROM alpine:3.24
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -21,7 +21,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -4,7 +4,7 @@
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.24
RUN apk add --no-cache \
bash \
@@ -27,7 +27,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -23,7 +23,6 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
@@ -76,13 +75,6 @@ type Client struct {
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
// mdmLoader holds the per-Client MDM policy source. Set by
// SetMDMPolicyFetcher (called from the Kotlin side). Each Run
// passes this loader to the resolved Config so applyMDMPolicy
// picks up the active overlay. Nil means "MDM enforcement off
// for this Client".
mdmLoader *mdm.Loader
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
@@ -137,7 +129,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
if err != nil {
return err
}
c.applyMDMOverlay(cfg)
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -182,7 +173,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
if err != nil {
return err
}
c.applyMDMOverlay(cfg)
c.recorder.UpdateManagementAddress(cfg.ManagementURL.String())
c.recorder.UpdateRosenpass(cfg.RosenpassEnabled, cfg.RosenpassPermissive)
@@ -240,7 +230,6 @@ func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (strin
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
c.applyMDMOverlay(cfg)
cacheDir = platformFiles.CacheDir()
}

View File

@@ -1,80 +0,0 @@
//go:build android
package android
import (
"encoding/json"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
)
// PolicyFetcher is the mobile-side bridge for the MDM managed-config
// snapshot. The native layer (Kotlin) implements this and registers
// the instance per Client via Client.SetMDMPolicyFetcher. Every
// invocation of fetchJSON must read the current RestrictionsManager
// state and return the result as a JSON-encoded map[string]any string.
//
// JSON is used because gomobile does not support map[string]any
// crossing the JNI boundary — the adapter on the Go side parses the
// string back into the map[string]any expected by mdm.Loader.
//
// Return value contract:
// - "" (empty) : interpreted as "no MDM source / no managed keys"
// - "{}" : managed config explicitly empty
// - "{...}" : JSON object with key/value pairs
// - malformed JSON : logged and treated as empty
type PolicyFetcher interface {
FetchJSON() string
}
// jsonFetcherAdapter wraps a gomobile-exposed PolicyFetcher into the
// internal mdm.PolicyFetcher interface, taking care of JSON decoding
// on every Fetch.
type jsonFetcherAdapter struct {
inner PolicyFetcher
}
func (a *jsonFetcherAdapter) Fetch() map[string]any {
raw := a.inner.FetchJSON()
if raw == "" {
return nil
}
var out map[string]any
if err := json.Unmarshal([]byte(raw), &out); err != nil {
log.Warnf("MDM mobile fetcher: invalid JSON payload from native: %v", err)
return nil
}
return out
}
// SetMDMPolicyFetcher registers the native-provided MDM policy fetcher
// on this Client. Call once from the gomobile-init code (Kotlin
// Application.onCreate or Service onCreate) before invoking Run /
// RunWithoutLogin. Passing nil disables MDM enforcement on this
// Client.
//
// The fetcher is held as a *mdm.Loader instance on the Client (no
// package-level state) — multiple Clients in the same process get
// independent Loaders, and tests can inject fakes per Client.
func (c *Client) SetMDMPolicyFetcher(p PolicyFetcher) {
if p == nil {
c.mdmLoader = mdm.NewLoader(nil)
return
}
c.mdmLoader = mdm.NewLoader(&jsonFetcherAdapter{inner: p})
}
// applyMDMOverlay applies the Client-held MDM Loader's current policy
// on top of the just-read Config. Called immediately after every
// UpdateOrCreateConfig — profilemanager's apply() initialises the
// policy to empty and leaves overlay responsibility to the lifecycle
// owner. No-op when no fetcher was registered.
func (c *Client) applyMDMOverlay(cfg *profilemanager.Config) {
if cfg == nil || c.mdmLoader == nil {
return
}
cfg.ApplyMDMPolicy(c.mdmLoader.Load())
}

View File

@@ -17,7 +17,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/util"
@@ -249,11 +248,6 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
if err != nil {
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
// CLI standalone login: profilemanager no longer auto-applies MDM,
// so layer in the OS-native policy here. Desktop builds construct
// a Loader with no fetcher — the build-tagged loadPlatform reads
// the registry/plist directly.
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil {

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -188,10 +187,6 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
if err != nil {
return fmt.Errorf("get config file: %v", err)
}
// CLI foreground path runs without the daemon Server: layer in the
// active MDM policy explicitly so a forced ManagementURL / PSK /
// other managed key actually takes effect on this run.
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/mdm"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
@@ -216,10 +215,6 @@ func New(opts Options) (*Client, error) {
if err != nil {
return nil, fmt.Errorf("create config: %w", err)
}
// Embedded path runs without the daemon Server: apply the active
// MDM policy explicitly so a forced ManagementURL / PSK / other
// managed key takes effect on this embedded engine instance.
config.ApplyMDMPolicy(mdm.NewLoader(nil).Load())
if opts.PrivateKey != "" {
config.PrivateKey = opts.PrivateKey

View File

@@ -41,7 +41,6 @@ type ICEBind struct {
*wgConn.StdNetBind
transportNet transport.Net
filterFn udpmux.FilterFn
address wgaddr.Address
mtu uint16
@@ -61,12 +60,11 @@ type ICEBind struct {
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
transportNet: transportNet,
filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn),
@@ -265,7 +263,6 @@ func (s *ICEBind) createOrUpdateMux() {
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},

View File

@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
return NewICEBind(transportNet, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {

View File

@@ -1,10 +1,13 @@
package device
import (
"fmt"
"net/netip"
"runtime/debug"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
@@ -41,10 +44,13 @@ type PacketCapture interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
capture atomic.Pointer[PacketCapture]
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
capture atomic.Pointer[PacketCapture]
// panicHandler is invoked after a panic in the underlying device is
// recovered in Read or Write.
panicHandler atomic.Pointer[func()]
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -70,7 +76,7 @@ func (d *FilteredDevice) Close() error {
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
return 0, err
}
@@ -112,7 +118,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RUnlock()
if filter == nil {
return d.Device.Write(bufs, offset)
return d.deviceWrite(bufs, offset)
}
filteredBufs := make([][]byte, 0, len(bufs))
@@ -125,9 +131,44 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
}
}
n, err := d.Device.Write(filteredBufs, offset)
n += dropped
return n, err
n, err := d.deviceWrite(filteredBufs, offset)
if err != nil {
return n, err
}
return n + dropped, nil
}
// deviceRead calls the underlying device Read, recovering from panics in the
// wintun read path and converting them into errors.
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
defer d.recoverFromPanic("read", &n, &err)
return d.Device.Read(bufs, sizes, offset)
}
// deviceWrite calls the underlying device Write, recovering from panics in the
// wintun write path and converting them into errors.
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
defer d.recoverFromPanic("write", &n, &err)
return d.Device.Write(bufs, offset)
}
// recoverFromPanic converts a panic in the underlying device into a regular
// error and invokes the registered panic handler. The wintun read path is
// known to panic on zero-length packets that third-party filter drivers can
// place in the ring.
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
r := recover()
if r == nil {
return
}
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
*n = 0
*err = fmt.Errorf("tun device %s panic: %v", op, r)
if handler := d.panicHandler.Load(); handler != nil {
(*handler)()
}
}
// SetFilter sets packet filter to device
@@ -137,6 +178,17 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Unlock()
}
// SetPanicHandler registers a handler invoked after a recovered panic in Read
// or Write. The device is unusable after such a panic; the handler should
// trigger recreation of the interface. Pass nil to remove.
func (d *FilteredDevice) SetPanicHandler(handler func()) {
if handler == nil {
d.panicHandler.Store(nil)
return
}
d.panicHandler.Store(&handler)
}
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.

View File

@@ -221,3 +221,60 @@ func TestDeviceWrapperRead(t *testing.T) {
}
})
}
func TestDeviceWrapperReadPanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
// Reproduce the wintun zero-length packet panic (index out of range).
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}
func TestDeviceWrapperWritePanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}

View File

@@ -32,8 +32,6 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *udpmux.UniversalUDPMuxDefault
filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -104,7 +102,6 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}

View File

@@ -63,7 +63,6 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn udpmux.FilterFn
DisableDNS bool
}

View File

@@ -11,7 +11,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -9,7 +9,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{

View File

@@ -10,7 +10,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),

View File

@@ -14,7 +14,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
userspaceBind: true,
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
userspaceBind: true,

View File

@@ -8,8 +8,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -22,10 +20,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
@@ -43,7 +37,6 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
MTU uint16
}
@@ -68,7 +61,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
@@ -115,15 +107,12 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
}
}
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
type UDPConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
address wgaddr.Address
}
// GetPacketConn returns the underlying PacketConn
@@ -132,65 +121,16 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return u.PacketConn.WriteTo(b, addr)
}
if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}
return u.handleUncachedAddress(b, addr)
}
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}
a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}
if u.address.Network.Contains(a) {
dst := udpAddr.AddrPort().Addr().Unmap()
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}
func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
return u.PacketConn.WriteTo(b, addr)
}
// GetSharedConn returns the shared udp conn
@@ -225,6 +165,13 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
return nil
}
src := udpAddr.AddrPort().Addr().Unmap()
wg := m.params.WGAddress
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {

View File

@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,

View File

@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,

View File

@@ -118,6 +118,8 @@ func (c *ConnectClient) RunOniOS(
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
stateFilePath string,
cacheDir string,
logFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
@@ -127,8 +129,9 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
return c.run(mobileDependency, nil, logFilePath)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {

View File

@@ -250,6 +250,7 @@ type BundleGenerator struct {
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
statePath string
cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
@@ -276,6 +277,7 @@ type GeneratorDependencies struct {
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
CPUProfile []byte
CapturePath string
RefreshStatus func()
@@ -299,6 +301,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
statePath: deps.StatePath,
cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
@@ -850,8 +853,11 @@ func (g *BundleGenerator) maskSecrets() {
}
func (g *BundleGenerator) addStateFile() error {
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
path := g.statePath
if path == "" {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
if path == "" {
return nil
}

View File

@@ -0,0 +1,36 @@
//go:build ios
package debug
import (
"path/filepath"
log "github.com/sirupsen/logrus"
)
// swiftLogFile is the Swift app log written by the iOS app into the same log
// directory as the Go client log, so it can be collected into the bundle.
const swiftLogFile = "swift-log.log"
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
// client log (logPath) with rotation, the stderr/stdout companions and
// anonymization. The iOS app writes its own Swift log into the same directory,
// so we add it alongside the Go log.
func (g *BundleGenerator) addPlatformLog() error {
if err := g.addLogfile(); err != nil {
return err
}
if g.logPath == "" {
return nil
}
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
// The Swift log is best-effort: the app may not have written it yet.
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build !android && !ios
package debug

View File

@@ -53,7 +53,6 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/syncstore"
"github.com/netbirdio/netbird/client/internal/updater"
@@ -240,7 +239,7 @@ type Engine struct {
syncStore syncstore.Store
syncStoreDir string
flowManager nftypes.FlowManager
flowManager nftypes.FlowManager
// auto-update
updateManager *updater.Manager
@@ -531,6 +530,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("create wg interface: %w", err)
}
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
filteredDevice.SetPanicHandler(e.triggerClientRestart)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
@@ -1711,6 +1714,13 @@ func (e *Engine) receiveSignalEvents() {
return e.ctx.Err()
}
// Self-addressed heartbeat: the signal client's receive watchdog
// round-trips this through the server to confirm the receive stream
// is delivering. Liveness is already recorded before this handler.
if msg.GetBody().GetType() == sProto.Body_HEARTBEAT {
return nil
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
@@ -1909,7 +1919,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
WGPrivKey: e.config.WgPrivateKey.String(),
MTU: e.config.MTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
}
@@ -2157,21 +2166,6 @@ func (e *Engine) startNetworkMonitor() {
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.routeManager.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
if e.dnsServer == nil {
return

View File

@@ -1024,14 +1024,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates)
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
states := d.relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
@@ -1041,10 +1044,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return relayStates
}
relayState := relay.ProbeResult{
URI: instanceAddr,
for _, rs := range states {
relayStates = append(relayStates, relay.ProbeResult{
URI: rs.URL,
Err: rs.Err,
Transport: rs.Transport,
})
}
return append(relayStates, relayState)
return relayStates
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
@@ -1405,6 +1412,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
Transport: relayState.Transport,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -165,10 +164,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
@@ -589,34 +584,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
return ec, nil
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
for _, prefix := range routePrefixes {
// default route is handled by route exclusion / ip rules
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}

View File

@@ -58,6 +58,10 @@ var DefaultInterfaceBlacklist = []string{
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}
// loadMDMPolicy is the package-level indirection used by apply() to read the
// active MDM policy. Tests override this to inject a fake policy.
var loadMDMPolicy = mdm.LoadPolicy
// ConfigInput carries configuration changes to the client
type ConfigInput struct {
ManagementURL string
@@ -176,27 +180,14 @@ type Config struct {
MTU uint16
// policy is the MDM policy that produced the currently-set values
// for any MDM-enforced fields. Set by ApplyMDMPolicy on every
// invocation. Never persisted to disk. Callers query enforcement
// state via Policy() and the mdm.Policy API (HasKey, ManagedKeys,
// IsEmpty).
// policy is the MDM policy that produced the currently-set values for
// any MDM-enforced fields. Set by applyMDMPolicy at the tail of apply()
// and reset on every apply() invocation. Never persisted to disk.
// Callers query enforcement state via Policy() and the mdm.Policy API
// (HasKey, ManagedKeys, IsEmpty).
policy *mdm.Policy `json:"-"`
}
// ApplyMDMPolicy overlays the supplied MDM Policy on top of the
// currently resolved Config values. Idempotent — pass an empty Policy
// to clear any prior overlay. The lifecycle owner (Server.getConfig
// on desktop, the Client.Run path on mobile) calls this with
// loader.Load() once the per-process Loader is known; the Config
// itself holds no reference to the Loader.
func (config *Config) ApplyMDMPolicy(policy *mdm.Policy) {
if config == nil {
return
}
config.applyMDMPolicy(policy)
}
// Policy returns the MDM policy applied to this Config. Returns a non-nil
// empty Policy when MDM enforcement is inactive; callers can always invoke
// HasKey / ManagedKeys / IsEmpty without a nil check.
@@ -643,11 +634,9 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
// Initialise the MDM overlay to "no enforcement" so Config.Policy()
// never returns a stale or nil policy on a freshly applied Config.
// Lifecycle owners that want to enforce a real MDM policy invoke
// Config.ApplyMDMPolicy(loader.Load()) after this returns.
config.applyMDMPolicy(mdm.NewPolicy(nil))
// MDM is the last override layer: any key present in the policy
// supersedes defaults, on-disk config, env vars and CLI input.
config.applyMDMPolicy(loadMDMPolicy())
return updated, nil
}

View File

@@ -10,58 +10,24 @@ import (
"github.com/netbirdio/netbird/client/mdm"
)
// fakeFetcher implements mdm.PolicyFetcher returning a pre-set policy
// map. Test helper used to construct a Loader without touching the OS
// or any package-level state.
type fakeFetcher struct{ values map[string]any }
func (f *fakeFetcher) Fetch() map[string]any { return f.values }
// loaderFor builds an mdm.Loader whose loadPlatform returns the
// supplied Policy's underlying values.
func loaderFor(policy *mdm.Policy) *mdm.Loader {
if policy == nil || policy.IsEmpty() {
return mdm.NewLoader(&fakeFetcher{values: nil})
}
values := make(map[string]any)
for _, k := range policy.ManagedKeys() {
if v, ok := policy.GetString(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetBool(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetInt(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetStringSlice(k); ok {
values[k] = v
}
}
return mdm.NewLoader(&fakeFetcher{values: values})
}
// configWithMDM is the test convenience that builds a Config via
// UpdateOrCreateConfig and overlays the supplied MDM policy on top —
// mirrors the production pattern (Server.getConfig / Client.applyMDMOverlay)
// where the Loader lives outside Config and the apply step is driven
// by the lifecycle owner.
func configWithMDM(t *testing.T, input ConfigInput, policy *mdm.Policy) *Config {
// withMDMPolicy temporarily overrides the package-level loadMDMPolicy hook so
// apply() observes the supplied Policy. The original loader is restored at
// test cleanup.
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
t.Helper()
cfg, err := UpdateOrCreateConfig(input)
require.NoError(t, err)
require.NotNil(t, cfg)
cfg.ApplyMDMPolicy(loaderFor(policy).Load())
return cfg
prev := loadMDMPolicy
loadMDMPolicy = func() *mdm.Policy { return policy }
t.Cleanup(func() { loadMDMPolicy = prev })
}
func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
cfg := configWithMDM(t, ConfigInput{
withMDMPolicy(t, mdm.NewPolicy(nil))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(nil))
})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.True(t, cfg.Policy().IsEmpty(), "no MDM source ⇒ empty Policy")
assert.False(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
@@ -73,15 +39,18 @@ func TestApply_MDMEmpty_NoEnforcement(t *testing.T) {
func TestApply_MDMOnly_OverridesDefaults(t *testing.T) {
const mdmURL = "https://corp.mdm.example.com:443"
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
mdm.KeyDisableClientRoutes: true,
mdm.KeyBlockInbound: true,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
assert.True(t, cfg.DisableClientRoutes)
assert.True(t, cfg.BlockInbound)
@@ -96,25 +65,33 @@ func TestApply_MDMBeatsCLIInput(t *testing.T) {
const mdmURL = "https://mdm.example.com:443"
const cliURL = "https://cli.example.com:443"
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
ManagementURL: cliURL,
}, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: mdmURL,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
ManagementURL: cliURL,
})
require.NoError(t, err)
require.NotNil(t, cfg)
// MDM wins over CLI-supplied management URL.
assert.Equal(t, mdmURL, cfg.ManagementURL.String())
assert.True(t, cfg.Policy().HasKey(mdm.KeyManagementURL))
}
func TestApply_MDMInvalidURL_KeepsPreviousValue(t *testing.T) {
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "not-a-url",
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
// Invalid MDM URL is logged and skipped: default URL stays in place
// to keep the client functional.
assert.Equal(t, DefaultManagementURL, cfg.ManagementURL.String())
@@ -129,20 +106,24 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
tmp := filepath.Join(t.TempDir(), "config.json")
// Seed without MDM.
configWithMDM(t, ConfigInput{
withMDMPolicy(t, mdm.NewPolicy(nil))
_, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: tmp,
DisableClientRoutes: boolPtr(false),
RosenpassEnabled: boolPtr(false),
}, mdm.NewPolicy(nil))
})
require.NoError(t, err)
// Now enable MDM enforcement for these keys.
cfg := configWithMDM(t, ConfigInput{
ConfigPath: tmp,
}, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyDisableClientRoutes: true,
mdm.KeyRosenpassEnabled: true,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{ConfigPath: tmp})
require.NoError(t, err)
require.NotNil(t, cfg)
assert.True(t, cfg.DisableClientRoutes, "MDM override should flip on-disk false to true")
assert.True(t, cfg.RosenpassEnabled)
assert.True(t, cfg.Policy().HasKey(mdm.KeyDisableClientRoutes))
@@ -152,12 +133,16 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
const maskSentinel = "**********"
cfg := configWithMDM(t, ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
}, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyPreSharedKey: maskSentinel,
}))
cfg, err := UpdateOrCreateConfig(ConfigInput{
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
})
require.NoError(t, err)
require.NotNil(t, cfg)
// Mask sentinel must not be persisted as the actual PSK.
assert.NotEqual(t, maskSentinel, cfg.PreSharedKey)
// Key still marked managed so user writes are still rejected.

View File

@@ -32,6 +32,9 @@ type ProbeResult struct {
URI string
Err error
Addr string
// Transport is the negotiated relay transport, empty
// for stun/turn probes or when not connected.
Transport string
}
type StunTurnProbe struct {

View File

@@ -22,14 +22,14 @@ type removePeerCall struct {
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {

View File

@@ -41,4 +41,3 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -332,6 +333,8 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
}
m.notifier.Close()
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
@@ -700,6 +703,8 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
@@ -716,6 +721,24 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
m.logExitNodeUpdate(exitNodeInfo)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
// see consistent state, including pairs loaded from persisted selector state.
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
for haID, routes := range clientRoutes {
routesByNetID[haID.NetID()] = routes
}
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
m.routeSelector.SyncPairedSelection(baseID, v6ID)
}
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID

View File

@@ -0,0 +1,47 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
// v6 pair so FilterSelectedExitNodes drops it.
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
const (
v4ID = route.NetID("Exit Node (raspberrypi)")
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
)
all := []route.NetID{v4ID, v6ID}
rs := routeselector.NewRouteSelector()
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
m := &DefaultManager{routeSelector: rs}
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
clientRoutes := route.HAMap{
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
}
m.updateRouteSelectorFromManagement(clientRoutes)
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
filtered := rs.FilterSelectedExitNodes(clientRoutes)
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
}

View File

@@ -16,7 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoutes []*route.Route
fakeIPRoutes []*route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -119,3 +119,7 @@ func (n *Notifier) GetInitialRouteRanges() []string {
sort.Strings(initialStrings)
return initialStrings
}
func (n *Notifier) Close() {
// unused
}

View File

@@ -3,6 +3,7 @@
package notifier
import (
"container/list"
"net/netip"
"slices"
"sort"
@@ -14,19 +15,26 @@ import (
)
type Notifier struct {
mu sync.Mutex
cond *sync.Cond
currentPrefixes []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
listener listener.NetworkChangeListener
queue *list.List
closed bool
}
func NewNotifier() *Notifier {
return &Notifier{}
n := &Notifier{
queue: list.New(),
}
n.cond = sync.NewCond(&n.mu)
go n.deliverLoop()
return n
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.mu.Lock()
defer n.mu.Unlock()
n.listener = listener
}
@@ -43,32 +51,52 @@ func (n *Notifier) OnNewRoutes(route.HAMap) {
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
newNets := make([]string, 0, len(prefixes))
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
n.mu.Lock()
if slices.Equal(n.currentPrefixes, newNets) {
n.mu.Unlock()
return
}
n.currentPrefixes = newNets
n.notify()
routes := strings.Join(n.currentPrefixes, ",")
n.queue.PushBack(routes)
n.cond.Signal()
n.mu.Unlock()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.currentPrefixes, ","))
}(n.listener)
func (n *Notifier) Close() {
n.mu.Lock()
n.closed = true
n.cond.Signal()
n.mu.Unlock()
}
func (n *Notifier) GetInitialRouteRanges() []string {
return nil
}
func (n *Notifier) deliverLoop() {
for {
n.mu.Lock()
for n.queue.Len() == 0 && !n.closed {
n.cond.Wait()
}
if n.closed && n.queue.Len() == 0 {
n.mu.Unlock()
return
}
routes := n.queue.Remove(n.queue.Front()).(string)
l := n.listener
n.mu.Unlock()
if l != nil {
l.OnNetworkChanged(routes)
}
}
}

View File

@@ -38,3 +38,7 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}
func (n *Notifier) Close() {
// unused
}

View File

@@ -121,9 +121,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
switch runtime.GOOS {
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -132,6 +131,33 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
return rs.isSelectedLocked(routeID)
}
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
// toggle, so the v6 entry carries no independent selection of its own.
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
_, baseSelected := rs.selectedRoutes[baseID]
_, baseDeselected := rs.deselectedRoutes[baseID]
delete(rs.selectedRoutes, pairedID)
delete(rs.deselectedRoutes, pairedID)
switch {
case baseSelected:
rs.selectedRoutes[pairedID] = struct{}{}
case baseDeselected:
rs.deselectedRoutes[pairedID] = struct{}{}
}
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
@@ -151,14 +177,13 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
return rs.hasUserSelectionForRouteLocked(routeID)
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -187,83 +212,6 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
@@ -317,3 +265,59 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
return nil
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelectionForRouteLocked(netID) {
if rs.isSelectedLocked(netID) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}

View File

@@ -330,39 +330,73 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
// is literal and never infers a v6 entry's state from its v4 base; callers that know
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
// to follow the base, treating the pair as a single toggle.
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// The selector does not pair-resolve: the v6 entry is independent until synced.
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
// A route literally named "exit1-something" must never pair-resolve either.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1"))
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
})
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.True(t, rs.IsSelected("exit1"))
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
})
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
"v6 explicit state is cleared so it follows management like its base")
})
// Regression for the observed bug (see netbird-engine.log): persisted state has
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Prior state: both explicitly selected, then only the v4 base deselected,
// leaving the v6 entry as a stale explicit selection.
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -370,23 +404,14 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -399,6 +424,15 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
rs.DeselectAllRoutes()
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -25,6 +26,7 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -54,6 +56,7 @@ type selectRoute struct {
Network netip.Prefix
Domains domain.List
Selected bool
Status string
extraNetworks []netip.Prefix
}
@@ -65,6 +68,8 @@ func init() {
type Client struct {
cfgFile string
stateFile string
cacheDir string
logFilePath string
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
@@ -75,16 +80,21 @@ type Client struct {
onHostDnsFn func([]string)
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
}
// NewClient instantiate a new Client
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{
cfgFile: cfgFile,
stateFile: stateFile,
cacheDir: cacheDir,
logFilePath: logFilePath,
deviceName: deviceName,
osName: osName,
osVersion: osVersion,
@@ -161,8 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, connectClient)
// Persist the latest sync response so DebugBundle can include the network
// map. On iOS this is backed by disk to keep it out of the constrained
// process memory (see the syncstore package).
connectClient.SetSyncResponsePersistence(true)
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
}
// Stop the internal client and free the resources
@@ -174,6 +189,84 @@ func (c *Client) Stop() {
}
c.ctxCancel()
c.setState(nil, nil)
}
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
// It works with or without a running engine: when the engine is up it reuses
// the live config, sync response and client metrics; otherwise it loads the
// config from disk (or the preloaded tvOS config).
func (c *Client) DebugBundle(anonymize bool) (string, error) {
cfg, cc := c.stateSnapshot()
// If the engine hasn't been started, load config so we can reach management.
if cfg == nil {
if c.preloadedConfig != nil {
cfg = c.preloadedConfig
} else {
var err error
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
// (temp file + rename) blocked by the tvOS sandbox.
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
}
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: c.cacheDir,
StatePath: c.stateFile,
LogPath: c.logFilePath,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
@@ -227,6 +320,16 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
// IsLoginRequiredCached reports whether the LAST observed management error was an
// auth failure (PermissionDenied/InvalidArgument), using the in-memory status
// recorder. Unlike IsLoginRequired() it performs NO network call, so it is safe to
// call from the connection listener during teardown (e.g. onDisconnected) without
// blocking on a slow or unavailable network. Returns false while connected to
// management or when the last error was not auth-related.
func (c *Client) IsLoginRequiredCached() bool {
return c.recorder.IsLoginRequired()
}
func (c *Client) IsLoginRequired() bool {
var ctx context.Context
//nolint
@@ -354,11 +457,12 @@ func (c *Client) ClearLoginComplete() {
}
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
if c.connectClient == nil {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
engine := connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
@@ -377,9 +481,57 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
// Compute each route's connection status in the core (mirroring the Android
// bridge), so the UI doesn't have to infer it by string-matching the joined
// Network value against peer routes. For a merged exit node the status reflects
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
connectedRoutes := c.connectedRouteSet()
for _, r := range routes {
r.Status = routeStatus(r, connectedRoutes)
}
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
// connectedRouteSet returns the set of route keys (as strings) currently served by a
// connected peer, gathered across all connected peers' route tables. The keys match
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
// and the domain pattern for dynamic routes (e.g. "*.example.com").
func (c *Client) connectedRouteSet() map[string]struct{} {
connected := map[string]struct{}{}
for _, p := range c.recorder.GetFullStatus().Peers {
if p.ConnStatus != peer.StatusConnected {
continue
}
for r := range p.GetRoutes() {
connected[r] = struct{}{}
}
}
return connected
}
// routeStatus reports "Connected" if any of the route's keys is served by a connected
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
// domain pattern for a dynamic DNS route. Otherwise "Idle".
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
keys := make([]string, 0, 1+len(r.extraNetworks))
if len(r.Domains) > 0 {
keys = append(keys, r.Domains.SafeString())
} else {
keys = append(keys, r.Network.String())
}
for _, extra := range r.extraNetworks {
keys = append(keys, extra.String())
}
for _, k := range keys {
if _, ok := connectedRoutes[k]; ok {
return peer.StatusConnected.String()
}
}
return peer.StatusIdle.String()
}
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
var routes []*selectRoute
for id, rt := range routesMap {
@@ -462,6 +614,7 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
Status: r.Status,
})
}
@@ -470,11 +623,12 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
}
func (c *Client) SelectRoute(id string) error {
if c.connectClient == nil {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
return fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
engine := connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -500,10 +654,11 @@ func (c *Client) SelectRoute(id string) error {
}
func (c *Client) DeselectRoute(id string) error {
if c.connectClient == nil {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
return fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
engine := connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -527,6 +682,22 @@ func (c *Client) DeselectRoute(id string) error {
return nil
}
// setState stores the running engine state so DebugBundle can reuse the live
// config and ConnectClient. It is cleared on Stop.
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.connectClient = cc
}
// stateSnapshot returns the current config and ConnectClient under the lock.
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.connectClient
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")

View File

@@ -20,6 +20,7 @@ type RoutesSelectionInfo struct {
Network string
Domains *DomainDetails
Selected bool
Status string
}
type DomainCollection interface {

View File

@@ -84,48 +84,16 @@ func NewPolicy(values map[string]any) *Policy {
return &Policy{values: values}
}
// PolicyFetcher is implemented by mobile platforms (Android / iOS) that
// push the OS-managed configuration into the Go runtime instead of
// having Go read an on-disk source directly. Desktop platforms ignore
// this interface — Loader.loadPlatform on windows/darwin reads the
// registry / plist on its own. A Loader constructed with a non-nil
// fetcher delegates to it on mobile; passing nil disables MDM
// enforcement (loadPlatform returns nil values).
type PolicyFetcher interface {
Fetch() map[string]any
}
// Loader is the DI-friendly entry point for reading the active MDM
// policy. Construct one at the daemon's lifecycle owner (Server on
// desktop, gomobile-exposed bridge on mobile) and pass it to anything
// that needs to read MDM state (the reload ticker, profilemanager's
// Config). Each callsite has the Loader handed in instead of looking
// up package-level state.
type Loader struct {
fetcher PolicyFetcher
}
// NewLoader constructs a Loader. The fetcher is consulted only on
// mobile builds (ios || android); on desktop it is unused but accepted
// to keep a single constructor signature across platforms — pass nil
// on desktop.
func NewLoader(f PolicyFetcher) *Loader {
return &Loader{fetcher: f}
}
// Load reads the platform-native MDM configuration and returns a
// Policy. Returns an empty (but non-nil) Policy when no source is
// present, the source is empty, or the platform is unsupported.
// LoadPolicy reads the platform-native MDM configuration. Returns an
// empty (but non-nil) Policy when no source is present, the source is
// empty, or the platform is unsupported.
//
// Diagnostic logging differentiates the three states:
// - source absent / unsupported platform: trace log only
// - source present, zero keys: info "MDM enrolled (no managed keys)"
// - source present, N keys: info "MDM enrolled with N managed keys: [...]"
func (l *Loader) Load() *Policy {
if l == nil {
return &Policy{values: map[string]any{}}
}
values, err := l.loadPlatform()
func LoadPolicy() *Policy {
values, err := loadPlatformPolicy()
if err != nil {
log.Tracef("MDM policy load: %v", err)
return &Policy{values: map[string]any{}}

View File

@@ -25,10 +25,8 @@ import (
// writable plist, as a defense against tampered installs.
const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
// loadPlatform reads the MDM-managed configuration from the macOS
// managed-preferences plist at policyPlistPath. The Loader's fetcher
// field is unused on this platform — the plist is the authoritative
// source. Returns:
// loadPlatformPolicy reads the MDM-managed configuration from the macOS
// managed-preferences plist at policyPlistPath. Returns:
// - (nil, nil) when the plist is absent (device not MDM-enrolled for
// NetBird, or admin has not yet pushed a payload)
// - (map, nil) with N entries when N managed values are present
@@ -41,13 +39,7 @@ const policyPlistPath = "/Library/Managed Preferences/io.netbird.client.plist"
// skipped so a stray entry in the payload does not block startup.
// Native plist value types map naturally onto the Policy accessor
// expectations (GetString / GetBool / GetInt / GetStringSlice).
func (l *Loader) loadPlatform() (map[string]any, error) {
// Honour the injected fetcher when present so tests (and any
// future non-macOS MDM channel) can short-circuit the plist read
// with a scripted policy.
if l != nil && l.fetcher != nil {
return l.fetcher.Fetch(), nil
}
func loadPlatformPolicy() (map[string]any, error) {
f, err := os.Open(policyPlistPath)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {

View File

@@ -2,14 +2,13 @@
package mdm
// loadPlatform reads the OS-managed configuration via the native
// PolicyFetcher injected at Loader construction. Returns
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
// "no MDM source present" — when no fetcher was provided.
func (l *Loader) loadPlatform() (map[string]any, error) {
if l == nil || l.fetcher == nil {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
return nil, nil
}
return l.fetcher.Fetch(), nil
// loadPlatformPolicy is unused on mobile: the native layer (Swift on iOS,
// Kotlin/Java on Android) reads the OS managed-config store and pushes the
// resulting dictionary in-process via a gomobile entry point that lands in
// Phase 5 / Phase 6. The stub keeps the package compilable for mobile
// builds and returns (nil, nil) — the platform-absent sentinel that
// LoadPolicy in policy.go treats as "no MDM source present".
func loadPlatformPolicy() (map[string]any, error) {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
}

View File

@@ -2,17 +2,13 @@
package mdm
// loadPlatform reads the MDM policy on platforms without a native MDM
// channel (Linux, FreeBSD). When no fetcher was injected the policy is
// (nil, nil) — the platform-absent sentinel that Loader.Load treats as
// "MDM enforcement disabled". A non-nil fetcher takes precedence: it
// is the test-seam used by unit tests to inject a scripted policy
// without touching the OS, and the same hook supports any future
// non-mobile OS that grows an out-of-band MDM channel.
func (l *Loader) loadPlatform() (map[string]any, error) {
if l != nil && l.fetcher != nil {
return l.fetcher.Fetch(), nil
}
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see Loader.Load.
// loadPlatformPolicy returns no policy on platforms without an MDM channel
// (Linux, FreeBSD). MDM enforcement is off and the client behaves as if
// the feature did not exist. Returns (nil, nil) — the platform-absent
// sentinel the caller (LoadPolicy in policy.go) treats as "no MDM
// source present"; an error here would just translate to the same
// outcome with an extra log line.
func loadPlatformPolicy() (map[string]any, error) {
//nolint:nilnil // (nil, nil) is the documented platform-absent sentinel; see LoadPolicy.
return nil, nil
}

View File

@@ -150,12 +150,10 @@ func TestPolicy_GetStringSlice(t *testing.T) {
})
}
func TestLoader_NilFetcherReturnsEmpty(t *testing.T) {
// Loader.Load with no fetcher (desktop construction) must degrade
// gracefully and never return nil; on linux loadPlatform is a stub
// returning (nil, nil), and Load is expected to translate that
// into a non-nil empty Policy.
p := NewLoader(nil).Load()
func TestLoadPolicy_PlatformStubReturnsEmpty(t *testing.T) {
// loadPlatformPolicy is a stub on every OS for Phase 1. LoadPolicy must
// degrade gracefully and never return nil.
p := LoadPolicy()
require.NotNil(t, p)
assert.True(t, p.IsEmpty())
assert.Empty(t, p.ManagedKeys())

View File

@@ -61,10 +61,8 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
}
}
// loadPlatform reads the MDM-managed configuration from the Windows
// registry under HKLM\Software\Policies\NetBird. The Loader's fetcher
// field is unused on this platform — the registry is the
// authoritative source. Returns:
// loadPlatformPolicy reads the MDM-managed configuration from the
// Windows registry under HKLM\Software\Policies\NetBird. Returns:
// - (nil, nil) when the key is absent (device not MDM-enrolled for NetBird)
// - (map, nil) with N entries when N managed values are set (N may be 0)
// - (nil, err) on open / enumerate registry errors
@@ -72,13 +70,7 @@ func readRegistryValue(k registry.Key, name, canonical string, out map[string]an
// Per-value type coercion + skip-on-error is delegated to
// readRegistryValue. Unknown value names are logged and skipped so a
// malformed deployment does not block startup.
func (l *Loader) loadPlatform() (map[string]any, error) {
// Honour the injected fetcher when present so tests (and any
// future non-Windows MDM channel) can short-circuit the registry
// read with a scripted policy.
if l != nil && l.fetcher != nil {
return l.fetcher.Fetch(), nil
}
func loadPlatformPolicy() (map[string]any, error) {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, policyRegistryPath, registry.QUERY_VALUE)
if err != nil {
if errors.Is(err, registry.ErrNotExist) {

View File

@@ -15,33 +15,33 @@ import (
// instead, hence anticipating the ticker mechanism entirely.
const DefaultReloadInterval = 1 * time.Minute
// Ticker periodically re-reads the OS-native MDM policy via the
// injected Loader and invokes the onChange callback (supplied to Run)
// whenever the observed Policy diverges from the last observation
// (added / removed / changed keys). Launch with Run from a goroutine;
// cancel the supplied context to stop.
// policyLoader is the indirection through which the ticker reads the
// OS-native policy, both for the initial observation and on every tick.
// Production points it at LoadPolicy; tests in this package override it to
// feed a scripted sequence of policies without touching the real OS store.
var policyLoader = LoadPolicy
// Ticker periodically re-reads the OS-native MDM policy via LoadPolicy and
// invokes the onChange callback (supplied to Run) whenever the observed
// Policy diverges from the last observation (added / removed / changed
// keys). Launch with Run from a goroutine; cancel the supplied context
// to stop.
type Ticker struct {
interval time.Duration
loader *Loader
prev *Policy
}
// NewTicker constructs a Ticker that will re-read the OS-native policy
// every reloadInterval once Run is called. The Loader is injected so
// the ticker doesn't depend on any package-level state — production
// passes the daemon-owned Loader, tests pass a fake Loader (built with
// a fake PolicyFetcher).
//
// The initial snapshot is populated by calling loader.Load() at
// every reloadInterval once Run is called.
// The initial snapshot is populated by calling policyLoader at
// construction time so the first tick only fires
// onChange when the policy actually changed since boot — without
// this baseline the first tick would report every currently-managed
// key as "added" and trigger a spurious engine restart.
func NewTicker(reloadInterval time.Duration, loader *Loader) *Ticker {
func NewTicker(reloadInterval time.Duration) *Ticker {
return &Ticker{
interval: reloadInterval,
loader: loader,
prev: loader.Load(),
prev: policyLoader(),
}
}
@@ -58,7 +58,7 @@ func (t *Ticker) Run(ctx context.Context, onChange func(prev, curr *Policy) erro
log.Info("MDM policy reload ticker stopped")
return
case <-tk.C:
curr := t.loader.Load()
curr := policyLoader()
if policiesEqual(t.prev, curr) {
continue
}

View File

@@ -13,40 +13,28 @@ import (
// testReloadInterval for speeding up the ticker cadence under `go test`
const testReloadInterval = 1 * time.Second
// fakePolicyFetcher implements PolicyFetcher returning a scripted
// policy map. Goroutine-safe so the test can mutate the script while
// the ticker is observing it.
type fakePolicyFetcher struct {
mu sync.Mutex
values map[string]any
}
func (f *fakePolicyFetcher) Fetch() map[string]any {
f.mu.Lock()
defer f.mu.Unlock()
if f.values == nil {
return nil
}
out := make(map[string]any, len(f.values))
for k, v := range f.values {
out[k] = v
}
return out
}
func (f *fakePolicyFetcher) set(values map[string]any) {
f.mu.Lock()
defer f.mu.Unlock()
f.values = values
// withPolicyLoader overrides the package-level policyLoader for the duration
// of the test so the ticker observes a scripted policy instead of the real
// OS-native store. The original loader is restored on cleanup.
func withPolicyLoader(t *testing.T, fn func() *Policy) {
t.Helper()
prev := policyLoader
policyLoader = fn
t.Cleanup(func() { policyLoader = prev })
}
func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
fetcher := &fakePolicyFetcher{} // initial observation: empty (no enforcement)
loader := NewLoader(fetcher)
var mu sync.Mutex
current := NewPolicy(nil) // initial observation: empty (no enforcement)
withPolicyLoader(t, func() *Policy {
mu.Lock()
defer mu.Unlock()
return current
})
type change struct{ prev, curr *Policy }
changes := make(chan change, 1)
tk := NewTicker(testReloadInterval, loader)
tk := NewTicker(testReloadInterval)
require.Equal(t, testReloadInterval, tk.interval)
ctx, cancel := context.WithCancel(context.Background())
@@ -61,13 +49,15 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
})
close(done)
}()
// Stop Run and wait for it to exit before returning, so the test
// goroutine doesn't race the still-running ticker.
// Stop Run and wait for it to exit before returning, so the policyLoader
// restore in t.Cleanup can't race the ticker goroutine still reading it.
defer func() { cancel(); <-done }()
// Flip the OS-observed policy from empty to one managed key. The
// next tick must detect the diff and invoke onChange.
fetcher.set(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
// Flip the OS-observed policy from empty to one managed key. The next
// tick must detect the diff and invoke onChange.
mu.Lock()
current = NewPolicy(map[string]any{KeyManagementURL: "https://mdm.example.com:443"})
mu.Unlock()
select {
case c := <-changes:
@@ -79,11 +69,12 @@ func TestTicker_FiresOnChangeWithDelta(t *testing.T) {
}
func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
fetcher := &fakePolicyFetcher{values: map[string]any{KeyBlockInbound: true}}
loader := NewLoader(fetcher)
withPolicyLoader(t, func() *Policy {
return NewPolicy(map[string]any{KeyBlockInbound: true})
})
fired := make(chan struct{}, 1)
tk := NewTicker(testReloadInterval, loader)
tk := NewTicker(testReloadInterval)
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
@@ -99,8 +90,8 @@ func TestTicker_NoCallbackWhenPolicyUnchanged(t *testing.T) {
}()
defer func() { cancel(); <-done }()
// Over ~2 ticks at the 1s test cadence the policy never changes,
// so the diff guard must suppress the callback entirely.
// Over ~2 ticks at the 1s test cadence the policy never changes, so the
// diff guard must suppress the callback entirely.
select {
case <-fired:
t.Fatal("onChange fired despite an unchanged policy")

View File

@@ -1849,10 +1849,13 @@ func (x *ManagementState) GetError() string {
// RelayState contains the latest state of the relay
type RelayState struct {
state protoimpl.MessageState `protogen:"open.v1"`
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
URI string `protobuf:"bytes,1,opt,name=URI,proto3" json:"URI,omitempty"`
Available bool `protobuf:"varint,2,opt,name=available,proto3" json:"available,omitempty"`
Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"`
// transport is the negotiated relay transport (e.g. "ws", "quic"),
// empty for stun/turn probes or when not connected.
Transport string `protobuf:"bytes,4,opt,name=transport,proto3" json:"transport,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1908,6 +1911,13 @@ func (x *RelayState) GetError() string {
return ""
}
func (x *RelayState) GetTransport() string {
if x != nil {
return x.Transport
}
return ""
}
type NSGroupState struct {
state protoimpl.MessageState `protogen:"open.v1"`
Servers []string `protobuf:"bytes,1,rep,name=servers,proto3" json:"servers,omitempty"`
@@ -6486,12 +6496,13 @@ const file_daemon_proto_rawDesc = "" +
"\x0fManagementState\x12\x10\n" +
"\x03URL\x18\x01 \x01(\tR\x03URL\x12\x1c\n" +
"\tconnected\x18\x02 \x01(\bR\tconnected\x12\x14\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"R\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"p\n" +
"\n" +
"RelayState\x12\x10\n" +
"\x03URI\x18\x01 \x01(\tR\x03URI\x12\x1c\n" +
"\tavailable\x18\x02 \x01(\bR\tavailable\x12\x14\n" +
"\x05error\x18\x03 \x01(\tR\x05error\"r\n" +
"\x05error\x18\x03 \x01(\tR\x05error\x12\x1c\n" +
"\ttransport\x18\x04 \x01(\tR\ttransport\"r\n" +
"\fNSGroupState\x12\x18\n" +
"\aservers\x18\x01 \x03(\tR\aservers\x12\x18\n" +
"\adomains\x18\x02 \x03(\tR\adomains\x12\x18\n" +

View File

@@ -378,6 +378,9 @@ message RelayState {
string URI = 1;
bool available = 2;
string error = 3;
// transport is the negotiated relay transport (e.g. "ws", "quic"),
// empty for stun/turn probes or when not connected.
string transport = 4;
}
message NSGroupState {

View File

@@ -20,6 +20,10 @@ import (
// a no-op echo, never as a conflict with the policy.
const preSharedKeyRedactedSentinel = "**********"
// loadMDMPolicy is the indirection used by server handlers to read the
// active MDM policy. Tests override this to inject a fake policy.
var loadMDMPolicy = mdm.LoadPolicy
// conflictCheck is a value-aware comparison between a single field in
// the incoming request and the corresponding MDM-enforced value. It
// runs only when the field was actually set in the request (presence

View File

@@ -110,15 +110,6 @@ type Server struct {
// stopped by the rootCtx cancellation.
mdmTicker *mdm.Ticker
// mdmLoader is the daemon-owned source of the active MDM policy.
// Constructed once during Server.Start (with a nil PolicyFetcher on
// desktop — the build-tagged Loader.loadPlatform reads the OS
// registry / plist directly) and injected into every consumer:
// mdmTicker for its periodic reload, the SetConfig / Login MDM
// gates for conflict detection, and every Config produced via
// getConfig() so its apply() picks up the same overlay.
mdmLoader *mdm.Loader
updateManager *updater.Manager
jwtCache *jwtCache
@@ -182,14 +173,8 @@ func (s *Server) Start() error {
// Runs re-resolves Config (re-running profilemanager.Config.apply which
// applies the freshly-read MDM policy as the last layer) and brings
// the engine back with the new values.
if s.mdmLoader == nil {
// Desktop builds pass a nil PolicyFetcher: the Loader's
// build-tagged loadPlatform reads the OS source directly
// (registry on Windows, plist on macOS, no-op elsewhere).
s.mdmLoader = mdm.NewLoader(nil)
}
if s.mdmTicker == nil {
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval, s.mdmLoader)
s.mdmTicker = mdm.NewTicker(mdm.DefaultReloadInterval)
go s.mdmTicker.Run(s.rootCtx, s.onMDMPolicyChange)
}
@@ -385,7 +370,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
// by the active MDM policy. The error carries an MDMManagedFields-
// Violation detail listing the offending key names. Non-conflicting
// fields in the same request are not applied either.
policy := s.mdmLoader.Load()
policy := loadMDMPolicy()
if err := rejectMDMManagedFieldConflicts(mdmManagedFieldConflicts(msg, policy)); err != nil {
return nil, err
}
@@ -511,7 +496,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
if s.checkUpdateSettingsDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
policy := s.mdmLoader.Load()
policy := loadMDMPolicy()
if err := rejectMDMManagedFieldConflicts(loginRequestMDMConflicts(msg, policy)); err != nil {
return nil, err
}
@@ -1103,12 +1088,6 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
return nil, false, fmt.Errorf("failed to get config: %w", err)
}
// Apply the daemon-owned MDM policy on top of the just-resolved
// Config. profilemanager's apply() initialises the policy to
// empty — the Loader lives outside Config, so this overlay step
// is driven externally here.
config.ApplyMDMPolicy(s.mdmLoader.Load())
return config, configExisted, nil
}
@@ -1163,9 +1142,6 @@ func (s *Server) logoutFromProfile(ctx context.Context, profileName, username st
if err != nil {
return fmt.Errorf("profile '%s' not found", profileName)
}
// Honour any MDM-enforced ManagementURL when issuing the logout
// RPC: the user-stored value may have been overridden by policy.
config.ApplyMDMPolicy(s.mdmLoader.Load())
return s.sendLogoutRequestWithConfig(ctx, config)
}
@@ -1598,11 +1574,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
log.Errorf("failed to get active profile config: %v", err)
return nil, fmt.Errorf("failed to get active profile config: %w", err)
}
// Overlay the active MDM policy so the response's MDMManagedFields
// list reflects what the GUI / CLI must render as read-only.
// profilemanager.GetConfig itself returns a Config without the
// overlay (Loader lives outside profilemanager).
cfg.ApplyMDMPolicy(s.mdmLoader.Load())
managementURL := cfg.ManagementURL
adminURL := cfg.AdminURL

View File

@@ -16,40 +16,14 @@ import (
"github.com/netbirdio/netbird/client/proto"
)
// fakeMDMFetcher implements mdm.PolicyFetcher returning a pre-set
// policy map. Tests build one per Server instance to inject a
// scripted MDM overlay via a Loader rather than via package-level state.
type fakeMDMFetcher struct{ values map[string]any }
func (f *fakeMDMFetcher) Fetch() map[string]any { return f.values }
// withMDMPolicy installs an mdm.Loader on the given Server whose
// loadPlatform returns the supplied Policy's underlying values. Use
// after setupServerWithProfile to inject the scripted policy the
// SetConfig / Login MDM gates will observe.
func withMDMPolicy(t *testing.T, s *Server, policy *mdm.Policy) {
// withMDMPolicy temporarily overrides the server-package loadMDMPolicy hook
// so SetConfig observes the supplied Policy. Restores the original loader
// at test cleanup.
func withMDMPolicy(t *testing.T, policy *mdm.Policy) {
t.Helper()
values := map[string]any{}
if policy != nil {
for _, k := range policy.ManagedKeys() {
if v, ok := policy.GetString(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetBool(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetInt(k); ok {
values[k] = v
continue
}
if v, ok := policy.GetStringSlice(k); ok {
values[k] = v
}
}
}
s.mdmLoader = mdm.NewLoader(&fakeMDMFetcher{values: values})
prev := loadMDMPolicy
loadMDMPolicy = func() *mdm.Policy { return policy }
t.Cleanup(func() { loadMDMPolicy = prev })
}
// setupServerWithProfile mirrors the boilerplate of TestSetConfig_AllFieldsSaved:
@@ -115,11 +89,12 @@ func extractViolation(t *testing.T, err error) *proto.MDMManagedFieldsViolation
}
func TestSetConfig_MDMReject_SingleField(t *testing.T) {
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
Username: username,
@@ -131,13 +106,14 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
}
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
blockInbound := false
rosenpassEnabled := false
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
@@ -161,11 +137,12 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
// enforced field AND a non-enforced field (RosenpassEnabled).
// The whole request must be rejected — non-conflicting fields are not
// applied either.
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, cfgPath := setupServerWithProfile(t)
rosenpassEnabled := true
_, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
@@ -187,11 +164,12 @@ func TestSetConfig_MDMReject_AllOrNothing(t *testing.T) {
func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
// MDM enforces ManagementURL but the user only writes RosenpassEnabled.
// Request must succeed.
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(map[string]any{
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)
rosenpassEnabled := true
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,
@@ -205,8 +183,9 @@ func TestSetConfig_MDMAllow_NonManagedFields(t *testing.T) {
func TestSetConfig_MDMEmpty_NoEnforcement(t *testing.T) {
// No MDM policy active: any field can be written.
withMDMPolicy(t, mdm.NewPolicy(nil))
s, ctx, profName, username, _ := setupServerWithProfile(t)
withMDMPolicy(t, s, mdm.NewPolicy(nil))
resp, err := s.SetConfig(ctx, &proto.SetConfigRequest{
ProfileName: profName,

View File

@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
}
type RelayStateOutput struct {
@@ -219,7 +220,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
RelayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
Error: relayErrorString(relay.GetError()),
Transport: relay.GetTransport(),
},
)
@@ -235,6 +237,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
}
}
// relayErrorString flattens a newline-joined aggregated relay error onto a
// single line for status output.
func relayErrorString(s string) string {
return strings.ReplaceAll(s, "\n", "; ")
}
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
@@ -441,6 +449,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} else if relay.Transport != "" {
available = fmt.Sprintf("%s via %s", available, relay.Transport)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)

View File

@@ -647,3 +647,13 @@ func TestTimeAgo(t *testing.T) {
})
}
}
func TestMapRelaysTransport(t *testing.T) {
out := mapRelays([]*proto.RelayState{
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
})
require.Len(t, out.Details, 2)
assert.Equal(t, "quic", out.Details[0].Transport)
assert.Equal(t, "ws", out.Details[1].Transport)
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
"github.com/netbirdio/netbird/util"
)
@@ -30,6 +31,7 @@ const (
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
dialWebSocketTimeout = 30 * time.Second
icmpEchoRequest = 8
icmpCodeEcho = 0
@@ -677,6 +679,7 @@ func createClientObject(client *netbird.Client) js.Value {
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["dialWebSocket"] = createDialWebSocketMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
@@ -691,6 +694,74 @@ func createClientObject(client *netbird.Client) js.Value {
return js.ValueOf(obj)
}
func createDialWebSocketMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
if !errVal.IsUndefined() {
return errVal
}
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
return
}
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
})
})
}
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
if len(args) < 1 || args[0].Type() != js.TypeString {
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
}
url = args[0].String()
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
arr, err := jsStringArray(args[1])
if err != nil {
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
}
protocols = arr
}
timeout = dialWebSocketTimeout
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
if args[2].Type() != js.TypeNumber {
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
}
timeoutMs := args[2].Int()
if timeoutMs <= 0 {
return "", nil, 0, js.ValueOf("error: timeout must be positive")
}
timeout = time.Duration(timeoutMs) * time.Millisecond
}
return url, protocols, timeout, js.Undefined()
}
// jsStringArray converts a JS array of strings to a Go []string.
func jsStringArray(v js.Value) ([]string, error) {
if !v.InstanceOf(js.Global().Get("Array")) {
return nil, fmt.Errorf("expected array")
}
n := v.Length()
out := make([]string, n)
for i := 0; i < n; i++ {
el := v.Index(i)
if el.Type() != js.TypeString {
return nil, fmt.Errorf("element %d is not a string", i)
}
out[i] = el.String()
}
return out, nil
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {

View File

@@ -0,0 +1,304 @@
//go:build js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
netbird "github.com/netbirdio/netbird/client/embed"
log "github.com/sirupsen/logrus"
)
type closeError struct {
code uint16
reason string
}
func (e *closeError) Error() string {
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
}
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
// during the WebSocket handshake before falling through to the raw conn.
type bufferedConn struct {
net.Conn
r io.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
// Conn wraps a WebSocket connection over a NetBird TCP connection.
type Conn struct {
conn net.Conn
mu sync.Mutex
closed chan struct{}
closeOnce sync.Once
closeErr error
}
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
d := ws.Dialer{
NetDial: client.Dial,
Protocols: protocols,
}
conn, br, _, err := d.Dial(ctx, rawURL)
if err != nil {
return nil, fmt.Errorf("websocket dial: %w", err)
}
// br is non-nil when the server pushed frames alongside the handshake
// response; those bytes live in the bufio.Reader and must be drained
// before reading from conn, otherwise we'd skip the first frames.
if br != nil {
if br.Buffered() > 0 {
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
} else {
ws.PutReader(br)
}
}
return &Conn{
conn: conn,
closed: make(chan struct{}),
}, nil
}
// ReadMessage reads the next WebSocket message, handling control frames automatically.
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
for {
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
if err != nil {
return 0, nil, err
}
for _, msg := range msgs {
if msg.OpCode.IsControl() {
if err := c.handleControl(msg); err != nil {
return 0, nil, err
}
continue
}
return msg.OpCode, msg.Payload, nil
}
}
}
func (c *Conn) handleControl(msg wsutil.Message) error {
switch msg.OpCode {
case ws.OpPing:
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
case ws.OpClose:
code, reason := parseClosePayload(msg.Payload)
return &closeError{code: code, reason: reason}
default:
return nil
}
}
// WriteText sends a text WebSocket message.
func (c *Conn) WriteText(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
}
// WriteBinary sends a binary WebSocket message.
func (c *Conn) WriteBinary(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
}
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
func (c *Conn) Close() error {
return c.closeWith(ws.StatusNormalClosure, "")
}
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
var first bool
c.closeOnce.Do(func() {
first = true
close(c.closed)
c.mu.Lock()
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
c.mu.Unlock()
c.closeErr = c.conn.Close()
})
if !first {
return net.ErrClosed
}
return c.closeErr
}
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
// It exposes: send(string|Uint8Array), close(), and callback properties
// onmessage, onclose, onerror.
//
// Callback properties may be set from the JS thread while the read loop
// goroutine reads them. In WASM this is safe because Go and JS share a
// single thread, but the design would need synchronization on
// multi-threaded runtimes.
func NewJSInterface(conn *Conn) js.Value {
obj := js.Global().Get("Object").Call("create", js.Null())
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
log.Errorf("websocket send requires a data argument")
return js.ValueOf(false)
}
data := args[0]
switch data.Type() {
case js.TypeString:
if err := conn.WriteText([]byte(data.String())); err != nil {
log.Errorf("failed to send websocket text: %v", err)
return js.ValueOf(false)
}
default:
buf, err := jsToBytes(data)
if err != nil {
log.Errorf("failed to convert js value to bytes: %v", err)
return js.ValueOf(false)
}
if err := conn.WriteBinary(buf); err != nil {
log.Errorf("failed to send websocket binary: %v", err)
return js.ValueOf(false)
}
}
return js.ValueOf(true)
})
obj.Set("send", sendFunc)
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
if err := conn.Close(); err != nil {
log.Debugf("failed to close websocket: %v", err)
}
return js.Undefined()
})
obj.Set("close", closeFunc)
go func() {
defer func() {
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("close websocket on readLoop exit: %v", err)
}
}()
readLoop(conn, obj)
// Undefining before Release turns post-close JS calls into TypeError
// instead of a silent "call to released function".
obj.Set("send", js.Undefined())
obj.Set("close", js.Undefined())
sendFunc.Release()
closeFunc.Release()
}()
return obj
}
func jsToBytes(data js.Value) ([]byte, error) {
var uint8Array js.Value
switch {
case data.InstanceOf(js.Global().Get("Uint8Array")):
uint8Array = data
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
uint8Array = js.Global().Get("Uint8Array").New(data)
default:
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
}
buf := make([]byte, uint8Array.Get("length").Int())
js.CopyBytesToGo(buf, uint8Array)
return buf, nil
}
func readLoop(conn *Conn, obj js.Value) {
var ce *closeError
defer func() { invokeOnClose(obj, ce) }()
for {
select {
case <-conn.closed:
return
default:
}
op, payload, err := conn.ReadMessage()
if err != nil {
ce = handleReadError(conn, obj, err)
return
}
dispatchMessage(obj, op, payload)
}
}
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
var ce *closeError
if errors.As(err, &ce) {
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
log.Debugf("failed to close websocket after server close frame: %v", cerr)
}
return ce
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return nil
}
if onerror := obj.Get("onerror"); onerror.Truthy() {
onerror.Invoke(js.ValueOf(err.Error()))
}
return nil
}
func invokeOnClose(obj js.Value, ce *closeError) {
onclose := obj.Get("onclose")
if !onclose.Truthy() {
return
}
if ce != nil {
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
return
}
onclose.Invoke()
}
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
onmessage := obj.Get("onmessage")
if !onmessage.Truthy() {
return
}
switch op {
case ws.OpText:
onmessage.Invoke(js.ValueOf(string(payload)))
case ws.OpBinary:
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
js.CopyBytesToJS(uint8Array, payload)
onmessage.Invoke(uint8Array)
}
}
func parseClosePayload(payload []byte) (uint16, string) {
if len(payload) < 2 {
return 1005, "" // RFC 6455: No Status Rcvd
}
code := binary.BigEndian.Uint16(payload[:2])
return code, string(payload[2:])
}

View File

@@ -2,4 +2,5 @@ FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server
ARG TARGETPLATFORM
COPY ${TARGETPLATFORM}/netbird-server /go/bin/netbird-server

3
go.mod
View File

@@ -56,6 +56,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/gobwas/ws v1.4.0
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
@@ -215,6 +216,8 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.3 // indirect

7
go.sum
View File

@@ -249,6 +249,12 @@ github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4C
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@@ -845,6 +851,7 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -2,4 +2,5 @@ FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
COPY netbird-mgmt /go/bin/netbird-mgmt
ARG TARGETPLATFORM
COPY ${TARGETPLATFORM}/netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -1,5 +0,0 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
CMD ["--log-file", "console"]
COPY netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -45,7 +45,7 @@ type Controller struct {
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
affectedPeerUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
@@ -64,6 +64,13 @@ type bufferUpdate struct {
update atomic.Bool
}
type bufferAffectedUpdate struct {
sendMu sync.Mutex
dataMu sync.Mutex
next *time.Timer
peerIDs map[string]struct{}
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
@@ -201,7 +208,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
@@ -226,44 +233,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
@@ -273,6 +242,143 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
return c.sendUpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
return nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
if len(peersToUpdate) == 0 {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
return nil
}
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range peersToUpdate {
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
return true
}
}
return false
}
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
var result []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
result = append(result, peer)
}
}
return result
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
@@ -381,66 +487,164 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
if len(peerIDs) == 0 {
return nil
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
})
b := bufUpd.(*bufferAffectedUpdate)
b.addPeerIDs(peerIDs)
if !b.sendMu.TryLock() {
// Another goroutine is already sending; it will pick up our IDs on its next drain.
return nil
}
b.stopTimer()
// The send and the debounced timer outlive the calling request, so detach from
// its context to avoid sending with a cancelled context once the handler returns.
bgCtx := context.WithoutCancel(ctx)
collected := b.drainPeerIDs()
go func() {
defer b.sendMu.Unlock()
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, collected)
// Check if more peer IDs accumulated while we were sending.
if !b.hasPending() {
return
}
// Schedule a debounced flush for the newly accumulated IDs.
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
ids := b.drainPeerIDs()
if len(ids) > 0 {
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, ids)
}
})
}()
return nil
}
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
b.dataMu.Lock()
for _, id := range ids {
b.peerIDs[id] = struct{}{}
}
b.dataMu.Unlock()
}
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if len(b.peerIDs) == 0 {
return nil
}
ids := make([]string, 0, len(b.peerIDs))
for id := range b.peerIDs {
ids = append(ids, id)
}
b.peerIDs = make(map[string]struct{})
return ids
}
func (b *bufferAffectedUpdate) hasPending() bool {
b.dataMu.Lock()
defer b.dataMu.Unlock()
return len(b.peerIDs) > 0
}
func (b *bufferAffectedUpdate) stopTimer() {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next != nil {
b.next.Stop()
}
}
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next == nil {
b.next = time.AfterFunc(d, f)
return
}
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
return emptyMap, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, networkMap, postureChecks, dnsFwdPort, nil
return networkMap, postureChecks, dnsFwdPort, nil
}
// GetDNSDomain returns the configured dnsDomain
@@ -578,21 +782,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
return nil
}
return nil
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
@@ -625,7 +832,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -19,17 +19,19 @@ const (
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)

View File

@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
@@ -113,21 +127,20 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(int64)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
ret0, _ := ret[0].(*types.NetworkMap)
ret1, _ := ret[1].([]*posture.Checks)
ret2, _ := ret[2].(int64)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
}
// OnPeerConnected mocks base method.
@@ -158,45 +171,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
}
// StartWarmup mocks base method.
@@ -250,3 +263,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}

View File

@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
},
}
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
if err != nil {
return fmt.Errorf("failed to create proxy peer: %w", err)
}

View File

@@ -918,6 +918,10 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
}
for _, svc := range services {
if err = transaction.DeleteServiceTargets(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
@@ -1270,6 +1274,10 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
@@ -1319,6 +1327,10 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
return nil
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}

View File

@@ -458,6 +458,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -560,6 +563,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -604,6 +610,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -1192,6 +1201,67 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestDeleteExpiredPeerService_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before reaping")
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
err = mgr.deleteExpiredPeerService(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when an expired peer-exposed service is reaped")
}
func TestDeleteServiceFromPeer_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before stopping")
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "stopped peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when a peer stops its exposed service")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string

View File

@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: peerMeta,
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}

View File

@@ -1894,7 +1894,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -2577,7 +2577,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -2668,7 +2670,9 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
}
if updateNetworkMap {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return fmt.Errorf("notify network map controller: %w", err)
}
}

View File

@@ -13,6 +13,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -61,7 +62,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -69,7 +70,7 @@ type Manager interface {
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -108,8 +109,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
@@ -128,6 +129,7 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -15,6 +15,7 @@ import (
dns "github.com/netbirdio/netbird/dns"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
activity "github.com/netbirdio/netbird/management/server/activity"
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
idp "github.com/netbirdio/netbird/management/server/idp"
peer "github.com/netbirdio/netbird/management/server/peer"
posture "github.com/netbirdio/netbird/management/server/posture"
@@ -79,14 +80,15 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
}
// AddPeer mocks base method.
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret1, _ := ret[1].(*types.Network)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// AddPeer indicates an expected call of AddPeer.
@@ -1288,14 +1290,15 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
}
// LoginPeer mocks base method.
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret1, _ := ret[1].(*types.Network)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// LoginPeer indicates an expected call of LoginPeer.
@@ -1320,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
// MarkPeerDisconnected mocks base method.
@@ -1637,6 +1640,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// ExpandAndUpdateAffected mocks base method.
func (m *MockManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ExpandAndUpdateAffected", ctx, accountID, snap, change)
}
// ExpandAndUpdateAffected indicates an expected call of ExpandAndUpdateAffected.
func (mr *MockManagerMockRecorder) ExpandAndUpdateAffected(ctx, accountID, snap, change interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandAndUpdateAffected", reflect.TypeOf((*MockManager)(nil).ExpandAndUpdateAffected), ctx, accountID, snap, change)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -84,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
setupKey = key.Key
}
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -1092,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1156,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1504,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
}, false)
@@ -1826,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1882,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1927,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
@@ -1935,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1956,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1980,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1990,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2017,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
}, false)
@@ -2052,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
}()
}
@@ -2080,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -2093,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3276,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
Status: &nbpeer.PeerStatus{
@@ -3305,6 +3305,19 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
for {
select {
case _, ok := <-ch:
if !ok {
return
}
case <-time.After(200 * time.Millisecond):
return
}
}
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3431,7 +3444,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3500,7 +3513,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3895,13 +3908,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)

View File

@@ -0,0 +1,117 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
// dependency crossed with the change-type that can alter it, asserting the
// resolver folds in exactly the peers whose map changes. A new dependency that
// the resolver fails to walk should fail one of these rows; a new change-type
// without a row is a coverage gap to add here.
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
type row struct {
name string
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
}
rows := []row{
{
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "policy-change/explicit-policy refreshes source+routing",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
return affectedpeers.Change{Policies: []*types.Policy{policy}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "policy-destinationresource/explicit-policy bridges to routing peer",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
return affectedpeers.Change{Policies: []*types.Policy{policy}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "resource-change refreshes source+routing on its network",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{Resources: []*resourceTypes.NetworkResource{
{ID: s.resourceID, NetworkID: s.networkID, GroupIDs: []string{s.resourceGroupID}},
}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "network-change refreshes source+routing on that network",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "posture-check-change refreshes source+routing of gated policy",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
Name: "cov-min-version",
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
}, true)
require.NoError(t, err)
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.SourcePostureChecks = []string{check.ID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
}
for _, r := range rows {
t.Run(r.name, func(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
})
}
}

View File

@@ -0,0 +1,143 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// An update spans an old and a new state. The affected set must be the UNION of
// peers reachable before and after the change; resolving only against the final
// state drops peers that were reachable but no longer are. These tests pin the
// two paths where the old state is reachable only by the changed object's
// previous references: detaching a resource group, and re-pointing a router peer.
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
// a resource is reachable by a source group via two destination resource groups;
// detaching one of them must still refresh that group's policy source peers, even
// though the post-update resource no longer maps to it.
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A second resource group + a second source group/peer that reaches the
// resource only through that second group.
const detachGroupID = "rs-detach-grp"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
const secondSourceGroupID = "rs-source-grp-2"
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
}))
resourcesManager, _, _ := s.managers()
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID, detachGroupID},
Enabled: true,
})
require.NoError(t, err)
// Policy granting the second source group access via the detach group.
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
require.NoError(t, err)
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
settleAffectedUpdates(secondSrcCh)
done := make(chan struct{})
go func() {
// Detaching the resource from detachGroup removes the second source's
// access; that source peer must be refreshed even though the post-update
// resource no longer maps to detachGroup.
peerShouldReceiveUpdate(t, secondSrcCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
}
}
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
// changing router.Peer within the same network must still refresh the OLD routing
// peer, which loses its routing role.
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
_, routersManager, _ := s.managers()
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
require.NoError(t, err)
require.Len(t, routers, 1)
router := routers[0]
oldRoutingPeer := router.Peer
require.NotEmpty(t, oldRoutingPeer)
// A new peer to become the routing peer in place of the old one.
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
settleAffectedUpdates(oldCh)
done := make(chan struct{})
go func() {
// The old routing peer stops serving the resource and must be refreshed.
peerShouldReceiveUpdate(t, oldCh)
close(done)
}()
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
ID: router.ID,
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: newRoutingPeer.ID, // repoint within the same network
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
}
}

View File

@@ -0,0 +1,255 @@
package server
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"sort"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// allPeerMaps computes the serialized per-peer network map for every peer in the
// account, mirroring the controller's compute path so the property test compares
// against real output.
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
t.Helper()
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
require.NoError(t, err)
account.InjectProxyPolicies(ctx)
validated := make(map[string]struct{}, len(account.Peers))
for id := range account.Peers {
validated[id] = struct{}{}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
out := make(map[string]string, len(account.Peers))
for peerID := range account.Peers {
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
// Network.Serial is an account-global counter bumped on every change; it
// is not a per-peer dependency, so normalize it out of the comparison.
if nm.Network != nil {
nm.Network.Serial = 0
}
out[peerID] = canonicalJSON(t, nm)
}
return out
}
// canonicalJSON marshals v and returns an order-insensitive string form: every
// JSON array is sorted by the canonical form of its elements. The network map's
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
// a raw JSON compare would report spurious changes.
func canonicalJSON(t *testing.T, v interface{}) string {
t.Helper()
b, err := json.Marshal(v)
require.NoError(t, err)
var parsed interface{}
require.NoError(t, json.Unmarshal(b, &parsed))
canonicalized, err := json.Marshal(sortAny(parsed))
require.NoError(t, err)
return string(canonicalized)
}
func sortAny(v interface{}) interface{} {
switch val := v.(type) {
case []interface{}:
for i := range val {
val[i] = sortAny(val[i])
}
sort.Slice(val, func(i, j int) bool {
bi, _ := json.Marshal(val[i])
bj, _ := json.Marshal(val[j])
return string(bi) < string(bj)
})
return val
case map[string]interface{}:
for k := range val {
val[k] = sortAny(val[k])
}
return val
default:
return v
}
}
// changedPeers returns the peer IDs whose serialized map differs between before
// and after.
func changedPeers(before, after map[string]string) []string {
var changed []string
for id, b := range before {
a, ok := after[id]
if !ok || a != b {
changed = append(changed, id)
}
}
for id := range after {
if _, ok := before[id]; !ok {
changed = append(changed, id)
}
}
return changed
}
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
// applies random changes, and asserts that the resolver's affected set is a
// superset of the peers whose real network map actually changed. If the resolver
// ever misses a dependency, a change will alter a peer's map without that peer
// appearing in the affected set, failing here.
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A pre-existing peer->resource policy so the resource/router bridge is live.
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
// Extra peers and groups to give mutations room to move membership around.
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
require.NoError(t, err)
extraPeers := make([]string, 0, 4)
for i := 0; i < 4; i++ {
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
extraPeers = append(extraPeers, p.ID)
}
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
for _, g := range extraGroups {
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
}
rng := rand.New(rand.NewSource(1))
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
for iter := 0; iter < 60; iter++ {
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
if apply == nil {
continue
}
before := allPeerMaps(t, s.manager, s.accountID)
resolvedSet := make(map[string]struct{})
resolve := func() {
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
snap, err := affectedpeers.Load(ctx, tx, s.accountID, change)
if err != nil {
return err
}
for _, id := range snap.Expand(ctx, s.accountID, change) {
resolvedSet[id] = struct{}{}
}
return nil
}))
}
// Resolve on both sides of the mutation and union: removals are visible
// only pre-apply (the leaving peer is still a member), additions only
// post-apply (the joining peer is now a member). Production captures both
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
// models that without coupling the test to each path's ordering.
resolve()
changedIDs := change.ChangedPeerIDs
apply()
resolve()
after := allPeerMaps(t, s.manager, s.accountID)
// The explicitly-changed peer's own map refresh is the caller's
// responsibility (the resolver returns the peers to propagate to), so it
// is allowed to be absent from the resolved set.
changedExplicitly := make(map[string]struct{}, len(changedIDs))
for _, id := range changedIDs {
changedExplicitly[id] = struct{}{}
}
for _, id := range changedPeers(before, after) {
if _, stillExists := after[id]; !stillExists {
continue
}
if _, isExplicit := changedExplicitly[id]; isExplicit {
continue
}
_, ok := resolvedSet[id]
require.Truef(t, ok,
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
iter, id, maps.Keys(resolvedSet), change)
}
}
}
// randomMutation picks a random change, returns the Change to resolve and a
// function that applies the underlying store mutation. apply is nil when the
// drawn mutation is a no-op for the current state.
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
t.Helper()
ctx := context.Background()
switch rng.Intn(3) {
case 0:
groupID := allGroups[rng.Intn(len(allGroups))]
peerID := allPeers[rng.Intn(len(allPeers))]
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
require.NoError(t, err)
if slicesContains(grp.Peers, peerID) {
return affectedpeers.Change{}, nil
}
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
func() {
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
}
case 1:
groupID := allGroups[rng.Intn(len(allGroups))]
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
require.NoError(t, err)
if len(grp.Peers) == 0 {
return affectedpeers.Change{}, nil
}
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
func() {
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
}
default:
src := allGroups[rng.Intn(len(allGroups))]
dst := allGroups[rng.Intn(len(allGroups))]
policy := &types.Policy{
Enabled: true,
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
Rules: []*types.PolicyRule{{
Enabled: true,
Sources: []string{src},
Destinations: []string{dst},
Action: types.PolicyTrafficActionAccept,
}},
}
return affectedpeers.Change{Policies: []*types.Policy{policy}},
func() {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
}
}
}
func slicesContains(s []string, v string) bool {
for _, x := range s {
if x == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,164 @@
package server
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// countingStore wraps a real store and counts the per-account collection loads
// the resolver performs, so a test can assert each is read at most once and that
// irrelevant collections are skipped entirely.
type countingStore struct {
store.Store
mu sync.Mutex
counts map[string]int
}
func newCountingStore(s store.Store) *countingStore {
return &countingStore{Store: s, counts: map[string]int{}}
}
func (c *countingStore) bump(name string) {
c.mu.Lock()
c.counts[name]++
c.mu.Unlock()
}
func (c *countingStore) count(name string) int {
c.mu.Lock()
defer c.mu.Unlock()
return c.counts[name]
}
func (c *countingStore) total() int {
c.mu.Lock()
defer c.mu.Unlock()
n := 0
for _, v := range c.counts {
n += v
}
return n
}
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
c.bump("policies")
return c.Store.GetAccountPolicies(ctx, ls, accountID)
}
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
c.bump("routes")
return c.Store.GetAccountRoutes(ctx, ls, accountID)
}
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
c.bump("nameservers")
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
}
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
c.bump("dnssettings")
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
}
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
c.bump("routers")
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
}
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
c.bump("resources")
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
}
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
c.bump("services")
return c.Store.GetAccountServices(ctx, ls, accountID)
}
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
// loads each per-account collection at most once per Resolve (memoization) even
// on a change that drives every bridge, and skips the services table when the
// account has no embedded proxy peers.
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
cs := newCountingStore(s.manager.Store)
// A group change that exercises policies, routers, resources and the bridge.
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
require.NoError(t, err)
affected := snap.Expand(ctx, s.accountID, change)
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
assert.LessOrEqualf(t, cs.count(name), 1,
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
}
assert.Equal(t, 0, cs.count("services"),
"services must not be loaded when the account has no embedded proxy peers")
}
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
// no group/peer signal touches no per-account collections beyond what its inputs
// require.
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
cs := newCountingStore(s.manager.Store)
// A bare network change drives only the router->source bridge: routers and
// resources are needed, but routes/nameservers/dnssettings/services are not.
_, err := affectedpeers.Load(ctx, cs, s.accountID, affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}})
require.NoError(t, err)
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
}
// TestAffectedPeers_QueryCount_ExpandReadsNothing is the core invariant of the
// Load/Expand split: Load (run inside the transaction) does all store reads;
// Expand (run after commit) must touch the store ZERO times, so it never holds
// the write lock and never reads post-commit state.
func TestAffectedPeers_QueryCount_ExpandReadsNothing(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
cs := newCountingStore(s.manager.Store)
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
require.NoError(t, err)
require.Greater(t, cs.total(), 0, "Load must read the store")
// Any store access during Expand would increment the same counter. Expand
// operates purely on the snapshot, so the count must not move.
readsAfterLoad := cs.total()
affected := snap.Expand(ctx, s.accountID, change)
assert.Contains(t, affected, s.routerPeerID, "Expand must still produce the affected peers from the snapshot")
assert.Equal(t, readsAfterLoad, cs.total(), "Expand must perform zero store reads — it operates purely on the loaded snapshot")
}

View File

@@ -0,0 +1,333 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
change := affectedpeers.Change{ChangedGroupIDs: changedGroupIDs}
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
if err != nil {
return nil
}
return snap.Expand(ctx, s.accountID, change)
}
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
if err != nil {
return nil
}
return snap.Expand(ctx, s.accountID, change)
}
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
assert.Contains(t, affected, s.routerPeerID,
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.routerGroupPeerID,
"changing the source group must refresh the routing peer defined via router.PeerGroups")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
assert.Contains(t, affected, s.sourcePeerID,
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID,
"a status change on a source peer must refresh the resource's routing peer that serves it")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID,
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const memberOnlyGroupID = "rs-memberonly-grp"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
}))
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
}
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const extraResourceGroupID = "rs-resource-grp-extra"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: extraResourceGroupID, Name: "rs-resource-extra",
}))
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
require.NoError(t, err)
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
ID: s.resourceID,
Type: types.ResourceTypeHost,
}))
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
assert.Contains(t, affected, s.routerPeerID,
"attaching a resource to a policy destination group must refresh the resource's routing peer")
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
t.Helper()
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
Name: "rs-min-version",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
},
}, true)
require.NoError(t, err)
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.SourcePostureChecks = []string{check.ID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
return check.ID
}
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
checkID := s.createPostureCheckGatedPolicy(t, ctx)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
ID: checkID,
Name: "rs-min-version",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
},
}, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: editing a posture check did not refresh source + routing peers")
}
}
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
resourcesManager, _, _ := s.managers()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/25",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
}
}
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
resourcesManager, routersManager, _ := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: disabledRouterPeer.ID,
Masquerade: true,
Metric: 9000,
Enabled: false,
})
require.NoError(t, err)
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
settleAffectedUpdates(disabledCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, disabledCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/25",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
}
}
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "groupiso")
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a source-group change for another resource")
}
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "peeriso")
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a source-peer change for another resource")
}

View File

@@ -0,0 +1,771 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// routerScenario captures the topology from the bug report:
//
// network ── router (routing peer) ── resource (in resourceGroup)
// independent peer ──(policy: source -> resource)──> resource
//
// The routing peer must be refreshed when a policy grants a source peer access
// to the resource, because the network map connects the source peer to the
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
// group, so static group/peer resolution alone cannot find it.
type routerScenario struct {
manager *DefaultAccountManager
updateManager *update_channel.PeersUpdateManager
accountID string
networkID string
sourcePeerID string // independent peer that the policy grants access from
sourceGroupID string // group containing the source peer
routerPeerID string // peer acting as the routing peer (direct router.Peer)
routerGroupPeerID string // peer that is a member of routerPeerGroup
routerPeerGroupID string // group used for router.PeerGroups
resourceID string // network resource
resourceGroupID string // group whose member is the resource (no peers)
unrelatedPeerID string // peer in no relevant entity
}
// setupRouterScenario builds the topology above with the default policy removed
// and channels NOT yet created, so callers control exactly when updates can flow.
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
t.Helper()
manager, updateManager, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
account, err := createAccount(manager, "router_scenario", userID, "")
require.NoError(t, err)
accountID := account.Id
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, p := range policies {
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
}
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
const (
sourceGroupID = "rs-source-grp"
routerPeerGroupID = "rs-router-grp"
resourceGroupID = "rs-resource-grp"
)
for _, g := range []*types.Group{
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
} {
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
}
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
ID: "rs-network",
AccountID: accountID,
Name: "rs-network",
})
require.NoError(t, err)
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
AccountID: accountID,
NetworkID: network.ID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
router := &routerTypes.NetworkRouter{
ID: "rs-router",
NetworkID: network.ID,
AccountID: accountID,
Masquerade: true,
Metric: 9999,
Enabled: true,
}
if directRouterPeer {
router.Peer = routerPeer.ID
} else {
router.PeerGroups = []string{routerPeerGroupID}
}
_, err = routersManager.CreateRouter(ctx, userID, router)
require.NoError(t, err)
return &routerScenario{
manager: manager,
updateManager: updateManager,
accountID: accountID,
networkID: network.ID,
sourcePeerID: sourcePeer.ID,
sourceGroupID: sourceGroupID,
routerPeerID: routerPeer.ID,
routerGroupPeerID: routerGroupPeer.ID,
routerPeerGroupID: routerPeerGroupID,
resourceID: resource.ID,
resourceGroupID: resourceGroupID,
unrelatedPeerID: unrelatedPeer.ID,
}
}
// peerToResourcePolicy builds a policy granting the source group access to the
// resource, referencing the resource by its group in the rule destination.
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
return &types.Policy{
Enabled: true,
Name: "peer-to-resource-by-group",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{sourceGroupID},
Destinations: []string{resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
}
// peerToResourcePolicyByResource builds a policy referencing the resource
// directly via DestinationResource rather than its group.
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
return &types.Policy{
Enabled: true,
Name: "peer-to-resource-by-resource",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{sourceGroupID},
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
Action: types.PolicyTrafficActionAccept,
},
},
}
}
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
// peers for the given policy.
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
change := affectedpeers.Change{Policies: []*types.Policy{policy}}
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
if err != nil {
return nil
}
return snap.Expand(ctx, s.accountID, change)
}
func TestAffectedPeers_SourcePeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
}
func TestAffectedPeers_RoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// BUG: the direct routing peer serves the resource's subnet to the source
// peer, so it must be refreshed when the policy is created. The policy path
// only resolves the literal rule groups (source group + resource group);
// the resource group has no peer members and the router peer is reachable
// only through the network, so it is dropped.
assert.Contains(t, affected, s.routerPeerID,
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
}
func TestAffectedPeers_RoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// Router defined via PeerGroups instead of a direct peer.
assert.Contains(t, affected, s.routerGroupPeerID,
"routing peer (router.PeerGroups member) serving the resource must be affected")
}
func TestAffectedPeers_DestResource_RoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
affected := s.resolvePolicyAffected(ctx, policy)
// When the resource is referenced via DestinationResource, RuleGroups()
// returns only the source group and the resource ID is not a peer, so
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
// all. The routing peer is dropped here too.
assert.Contains(t, affected, s.routerPeerID,
"routing peer must be affected when the resource is referenced via DestinationResource")
}
func TestAffectedPeers_DestResource_RoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerGroupPeerID,
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
}
func TestAffectedPeers_SourceResourcePeer_RoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// Source expressed as a direct peer (SourceResource), destination as resource group.
policy := &types.Policy{
Enabled: true,
Name: "sourceResource-peer-to-resource",
Rules: []*types.PolicyRule{
{
Enabled: true,
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
Destinations: []string{s.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
affected := s.resolvePolicyAffected(ctx, policy)
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
// handles SourceResource peers), but the routing peer is still missing.
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
}
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// Guard against an over-broad fix: a peer in no relevant entity must never
// be pulled in.
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A pre-existing policy grants the source group access to the resource.
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
// Drive an update through the resource manager and assert the routing peer
// is among the affected set by observing the channel. This path walks
// policies whose destinations reference the resource's groups, folds in the
// source groups, and loads the network's routers, so it reaches both the
// source peer and the routing peer.
permissionsManager := permissions.NewManager(s.manager.Store)
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh source peer + routing peer")
}
}
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
// every given channel so subsequent assertions start from a clean slate.
//
// Setup (CreateNetwork/CreateResource/CreateRouter) fires async UpdateAffectedPeers
// goroutines; draining first means the assertion only observes updates from the
// action under test, not setup stragglers.
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
time.Sleep(300 * time.Millisecond)
for _, ch := range chans {
drainPeerUpdates(ch)
}
}
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
}
}
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
}
}
func TestAffectedPeers_E2E_DestResource_RoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
}
}
func TestAffectedPeers_E2E_DeletePolicy_RoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
}
}
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
permissionsManager := permissions.NewManager(s.manager.Store)
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
return resourcesManager, routersManager, networksManager
}
type secondTopology struct {
networkID string
resourceID string
resourceGroupID string
routerPeerID string
}
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
t.Helper()
ctx := context.Background()
resourcesManager, routersManager, networksManager := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
resourceGroupID := "rs-resource-grp-" + suffix
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: resourceGroupID, Name: "rs-resource-" + suffix,
}))
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
ID: "rs-network-" + suffix,
AccountID: s.accountID,
Name: "rs-network-" + suffix,
})
require.NoError(t, err)
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
AccountID: s.accountID,
NetworkID: network.ID,
Name: "rs-resource-host-" + suffix,
Address: "10.40.50.0/24",
GroupIDs: []string{resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: network.ID,
AccountID: s.accountID,
Peer: routerPeer.ID,
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
return secondTopology{
networkID: network.ID,
resourceID: resource.ID,
resourceGroupID: resourceGroupID,
routerPeerID: routerPeer.ID,
}
}
func TestAffectedPeers_E2E_UpdatePolicy_BothRoutingPeers(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "b")
ctx := context.Background()
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, second.routerPeerID)
})
settleAffectedUpdates(srcCh, routerACh, routerBCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerACh)
peerShouldReceiveUpdate(t, routerBCh)
close(done)
}()
policy.Rules[0].Destinations = []string{second.resourceGroupID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
}
}
func TestAffectedPeers_E2E_UpdatePolicy_AddSource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const secondSourceGroupID = "rs-source-grp-2"
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
}))
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(newSrcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, newSrcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
}
}
func TestAffectedPeers_E2E_DestResource_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
}
}
func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, routersManager, _ := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: secondRouterPeer.ID,
Masquerade: true,
Metric: 9998,
Enabled: true,
})
require.NoError(t, err)
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
}
func TestAffectedPeers_DisabledRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
require.NoError(t, err)
require.Len(t, routers, 1)
routers[0].Enabled = false
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
}
func TestAffectedPeers_DisabledResource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
require.NoError(t, err)
res.Enabled = false
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_DisabledRule(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.Rules[0].Enabled = false
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerPeerID,
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_MultiRule(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "c")
ctx := context.Background()
policy := &types.Policy{
Enabled: true,
Name: "multi-rule-two-resources",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{s.sourceGroupID},
Destinations: []string{s.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{s.sourceGroupID},
Destinations: []string{second.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
}
func TestAffectedPeers_RouterOtherNetwork(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "d")
ctx := context.Background()
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a policy that does not target its resource")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,825 @@
// Package affectedpeers computes which peers' network maps a change touches, so
// only those peers are refreshed instead of the whole account.
//
// Two phases keep the dependency walk off the write transaction:
// - Load: reads the needed collections. Call INSIDE the mutating tx (consistent,
// and before a delete/removal severs the old state).
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
//
// Enabled is never consulted: toggling it is itself an observable change.
package affectedpeers
import (
"context"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// Snapshot is an in-memory view of the collections needed to expand a Change.
// Loaded in-tx, walked by Expand after commit. Only the collections the Change
// can touch are loaded; the rest stay nil (see Load).
type Snapshot struct {
policies []*types.Policy
routes []*route.Route
nsGroups []*nbdns.NameServerGroup
dnsSettings *types.DNSSettings
routers []*routerTypes.NetworkRouter
resources []*resourceTypes.NetworkResource
services []*rpservice.Service
proxyByCluster map[string][]string
groups map[string]*types.Group
groupPeers map[string]map[string]struct{} // groupID -> member peer IDs
}
// Load reads the collections a Change requires, inside the caller's tx. It mirrors
// Expand's walker preconditions, loading only what the change can touch.
func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snapshot, error) {
snap := &Snapshot{}
if c.isEmpty() {
return snap, nil
}
if err := snap.loadCollections(ctx, s, accountID, c); err != nil {
return nil, err
}
if err := snap.loadGroupIndex(ctx, s, accountID); err != nil {
return nil, err
}
return snap, nil
}
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
// collections a Change can touch, gated to what the walk needs.
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
// the resource<->router bridge can fire for any of these
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
if needsRoutersResources {
if err := snap.loadPolicyRoutersResources(ctx, s, accountID); err != nil {
return err
}
}
if hasGroupOrPeerChange {
if err := snap.loadRoutesAndProxy(ctx, s, accountID); err != nil {
return err
}
}
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
if err := snap.loadDNS(ctx, s, accountID); err != nil {
return err
}
}
return nil
}
// loadPolicyRoutersResources loads the policies plus the routers and resources
// the resource<->router bridge walks.
func (snap *Snapshot) loadPolicyRoutersResources(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.policies, err = s.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
if snap.routers, err = s.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
snap.resources, err = s.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
return err
}
// loadRoutesAndProxy loads the routes and the embedded-proxy services index.
func (snap *Snapshot) loadRoutesAndProxy(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.routes, err = s.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
return snap.loadProxyServices(ctx, s, accountID)
}
// loadDNS loads the nameserver groups and account DNS settings.
func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.nsGroups, err = s.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
snap.dnsSettings, err = s.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
return err
}
// loadProxyServices loads the embedded-proxy cluster index, and the services only
// when the account actually has embedded proxy peers.
func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil {
return err
}
if len(snap.proxyByCluster) == 0 {
return nil
}
snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
return err
}
// loadGroupIndex loads all groups (for group.Resources) and builds the
// group->member-peers index. Always needed: the bridge resolves group.Resources
// and Expand maps groups to member peers.
func (snap *Snapshot) loadGroupIndex(ctx context.Context, s store.Store, accountID string) error {
groups, err := s.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
snap.groups = make(map[string]*types.Group, len(groups))
snap.groupPeers = make(map[string]map[string]struct{}, len(groups))
for _, g := range groups {
snap.groups[g.ID] = g
members := make(map[string]struct{}, len(g.Peers))
for _, pID := range g.Peers {
members[pID] = struct{}{}
}
snap.groupPeers[g.ID] = members
}
return nil
}
// Change describes what changed in an account.
type Change struct {
ChangedGroupIDs []string
ChangedPeerIDs []string
Policies []*types.Policy
Routes []*route.Route
Routers []*routerTypes.NetworkRouter
Resources []*resourceTypes.NetworkResource
Networks []*networkTypes.Network
PostureCheckIDs []string
// DistributionGroupIDs are groups whose members are directly affected, with no
// dependency walk — the change distributes config to the groups' member peers
// only (nameserver groups, DNS DisabledManagementGroups), not through the
// policy/route reachability graph. Pass oldnew so both states refresh.
DistributionGroupIDs []string
// RemovedPeersByGroup: peers that left a group, keyed by that group. They are no
// longer in the group's member index but still lose its reachability, so they are
// folded in — but only when the group is linked (an unlinked group has no map
// impact), matching how current members are handled.
RemovedPeersByGroup map[string][]string
}
func (c Change) isEmpty() bool {
return len(c.ChangedGroupIDs) == 0 &&
len(c.ChangedPeerIDs) == 0 &&
len(c.Policies) == 0 &&
len(c.Routes) == 0 &&
len(c.Routers) == 0 &&
len(c.Resources) == 0 &&
len(c.Networks) == 0 &&
len(c.PostureCheckIDs) == 0 &&
len(c.DistributionGroupIDs) == 0 &&
len(c.RemovedPeersByGroup) == 0
}
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
// no store access. Run after the producing tx commits. Logs the full walk at
// trace level for diagnosing a miscalculation.
func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []string {
if c.isEmpty() {
return nil
}
r := newResolver(ctx, snap, accountID, c)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
r.walk()
return r.expand()
}
// Collect returns the affected group and direct-peer IDs without expanding groups
// to members. Test-only introspection; use Resolve otherwise.
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
if c.isEmpty() {
return nil, nil
}
snap, err := Load(ctx, s, accountID, c)
if err != nil {
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers collect: %v", err)
return nil, nil
}
r := newResolver(ctx, snap, accountID, c)
r.walk()
return setToSlice(r.groupSet), setToSlice(r.peerSet)
}
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
}
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
r.seedChangedGroupsFromPeers()
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
return r
}
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
// the group-driven walkers fire for memberships, not just direct peer references.
func (r *resolver) seedChangedGroupsFromPeers() {
if len(r.changedPeerSet) == 0 {
return
}
for groupID, members := range r.snap.groupPeers {
for pID := range r.changedPeerSet {
if _, ok := members[pID]; ok {
r.changedGroupSet[groupID] = struct{}{}
break
}
}
}
}
func (r *resolver) walk() {
r.collectFromExplicitPolicies()
r.collectFromExplicitRoutes(r.change.Routes)
r.collectFromExplicitRouters(r.change.Routers)
r.collectFromExplicitResources(r.change.Resources)
r.collectFromExplicitNetworks(r.change.Networks)
r.collectFromPostureChecks(r.change.PostureCheckIDs)
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into groupSet so expand() maps them to members, without the policy/
// route walk that changedGroupSet would trigger.
addAll(r.groupSet, r.change.DistributionGroupIDs)
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
r.collectFromPolicies()
r.collectFromRoutes()
r.collectFromNameServers()
r.collectFromDNSSettings()
r.collectFromNetworkRouters()
r.collectFromProxyServices()
}
r.collectResourceRouterBridge()
}
type resolver struct {
ctx context.Context
snap *Snapshot
accountID string
change Change
changedGroupSet map[string]struct{}
changedPeerSet map[string]struct{}
groupSet map[string]struct{}
peerSet map[string]struct{}
matchedPolicies []*types.Policy
networkIDs map[string]struct{}
}
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
seen := make(map[string]struct{})
var ids []string
for gID := range groupSet {
for pID := range r.snap.groupPeers[gID] {
if _, ok := seen[pID]; ok {
continue
}
seen[pID] = struct{}{}
ids = append(ids, pID)
}
}
return ids
}
func (r *resolver) expand() []string {
peerIDs := r.peerIDsForGroups(r.groupSet)
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for id := range r.peerSet {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
// Fold in removed peers only when their group is linked (in groupSet).
for groupID, removed := range r.change.RemovedPeersByGroup {
if _, linked := r.groupSet[groupID]; !linked {
continue
}
for _, id := range removed {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
log.WithContext(r.ctx).Tracef("affectedpeers expand: removed peer %s from linked group %s -> affected", id, groupID)
}
}
}
log.WithContext(r.ctx).Tracef("affectedpeers expand done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
return peerIDs
}
func (r *resolver) collectFromExplicitPolicies() {
for _, policy := range r.matchedPolicies {
if policy == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
}
}
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
// collectFromExplicitRouters folds changed routers' peers and marks their networks
// for the bridge. Passing the old router keeps a repointed router's previous peers
// affected without a post-commit read.
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.networkIDs[router.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitResources marks changed resources' networks for the bridge and
// treats their group IDs as changed, so policies targeting the resource via a
// now-detached (old) group still refresh.
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
resource.ID, resource.NetworkID, resource.GroupIDs)
addAll(r.changedGroupSet, resource.GroupIDs)
if resource.NetworkID != "" {
r.networkIDs[resource.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
// no groups/peers of its own.
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
if network.ID != "" {
r.networkIDs[network.ID] = struct{}{}
}
}
}
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
if len(postureCheckIDs) == 0 {
return
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if !policyReferencesPostureChecks(policy, ids) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromRoutes() {
for _, rt := range r.snap.routes {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromNameServers() {
if len(r.changedGroupSet) == 0 {
return
}
for _, ns := range r.snap.nsGroups {
if anyInSet(ns.Groups, r.changedGroupSet) {
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
addAll(r.groupSet, ns.Groups)
}
}
}
func (r *resolver) collectFromDNSSettings() {
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
return
}
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
if _, ok := r.changedGroupSet[gID]; ok {
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
r.groupSet[gID] = struct{}{}
}
}
}
func (r *resolver) collectFromNetworkRouters() {
for _, router := range r.networkRouters() {
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
r.networkIDs[router.NetworkID] = struct{}{}
}
}
func (r *resolver) collectFromProxyServices() {
if len(r.snap.proxyByCluster) == 0 || len(r.snap.services) == 0 {
return
}
services, proxyByCluster := r.snap.services, r.snap.proxyByCluster
expanded := r.expandChangedPeersWithGroups()
for _, svc := range services {
if svc == nil {
continue
}
proxyPeers := proxyByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
if !matchedByPeer && !matchedByAccessGroup {
continue
}
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
for _, pid := range proxyPeers {
r.peerSet[pid] = struct{}{}
}
for _, target := range svc.Targets {
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
r.peerSet[target.TargetId] = struct{}{}
}
}
addAll(r.groupSet, svc.AccessGroups)
}
}
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
if len(r.changedGroupSet) == 0 {
return r.changedPeerSet
}
ids := r.peerIDsForGroups(r.changedGroupSet)
if len(ids) == 0 {
return r.changedPeerSet
}
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
for id := range r.changedPeerSet {
merged[id] = struct{}{}
}
for _, id := range ids {
merged[id] = struct{}{}
}
return merged
}
// collectResourceRouterBridge crosses between source peers and routing peers, which
// are reachable only via resource -> network -> router, not through the policy's own
// groups: source -> router (targeted resources' networks), then router -> source.
func (r *resolver) collectResourceRouterBridge() {
r.bridgeSourceToRouters()
r.bridgeRoutersToSources()
}
func (r *resolver) bridgeSourceToRouters() {
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
if len(resourceIDs) == 0 {
return
}
networkIDs := r.resourceNetworkIDs(resourceIDs)
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
r.networkIDs[id] = struct{}{}
}
}
func (r *resolver) bridgeRoutersToSources() {
if len(r.networkIDs) == 0 {
return
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
setToSlice(r.networkIDs))
r.foldRoutersOnNetworks(r.networkIDs)
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
resourceIDs[resource.ID] = struct{}{}
}
}
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.groupSet, r.peerSet)
}
}
}
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
for _, router := range r.networkRouters() {
if _, ok := networkIDs[router.NetworkID]; !ok {
continue
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
}
}
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
networkIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := resourceIDs[resource.ID]; ok {
networkIDs[resource.NetworkID] = struct{}{}
}
}
return networkIDs
}
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
if policy == nil {
return false
}
destGroupSet := make(map[string]struct{})
for _, rule := range policy.Rules {
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
return true
}
for _, gID := range rule.Destinations {
destGroupSet[gID] = struct{}{}
}
}
if len(destGroupSet) == 0 {
return false
}
for gID := range destGroupSet {
group := r.snap.groups[gID]
if group == nil {
continue
}
for _, res := range group.Resources {
if isInSet(res.ID, resourceIDs) {
return true
}
}
}
return false
}
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
resourceIDs := make(map[string]struct{})
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
r.addGroupResourceIDs(destGroupSet, resourceIDs)
return resourceIDs
}
// collectPolicyDestinations adds direct destination resource IDs to resourceIDs and
// returns the referenced destination group IDs.
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
destGroupSet := make(map[string]struct{})
for _, policy := range policies {
if policy == nil {
continue
}
for _, rule := range policy.Rules {
addAll(destGroupSet, rule.Destinations)
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
resourceIDs[rule.DestinationResource.ID] = struct{}{}
}
}
}
return destGroupSet
}
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
for gID := range groupIDs {
group := r.snap.groups[gID]
if group == nil {
continue
}
for _, res := range group.Resources {
if res.ID != "" {
resourceIDs[res.ID] = struct{}{}
}
}
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
addAll(groupSet, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {
return true
}
}
return false
}
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
if res.Type != types.ResourceTypePeer || res.ID == "" {
return false
}
_, ok := set[res.ID]
return ok
}
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
for _, pid := range proxyPeers {
if _, ok := changedPeers[pid]; ok {
return true
}
}
for _, target := range svc.Targets {
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
continue
}
if _, ok := changedPeers[target.TargetId]; ok {
return true
}
}
return false
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func isInSet(id string, set map[string]struct{}) bool {
_, ok := set[id]
return ok
}
func addAll(set map[string]struct{}, slices ...[]string) {
for _, s := range slices {
for _, id := range s {
set[id] = struct{}{}
}
}
}
func toSet(ids []string) map[string]struct{} {
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
func setToSlice(set map[string]struct{}) []string {
s := make([]string, 0, len(set))
for id := range set {
s = append(s, id)
}
return s
}

View File

@@ -0,0 +1,140 @@
package affectedpeers
import (
"testing"
"github.com/stretchr/testify/assert"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/types"
)
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
// direct peers) the resolver folds in, for asserting the pure logic.
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
peerSet := map[string]struct{}{}
for _, p := range policies {
if p == nil {
continue
}
groups = append(groups, p.RuleGroups()...)
collectPolicyDirectPeers(p, peerSet)
}
for id := range peerSet {
peers = append(peers, id)
}
return groups, peers
}
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
assert.Empty(t, peers)
}
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
Destinations: []string{"g2"},
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
}
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
groups, peers := policyGroupsAndPeers(nil)
assert.Nil(t, groups)
assert.Nil(t, peers)
}
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
groups, _ := policyGroupsAndPeers(updated, old)
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
}
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
Destinations: []string{"g2"},
}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
}
func TestChangeIsEmpty(t *testing.T) {
assert.True(t, Change{}.isEmpty())
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
assert.False(t, Change{Resources: []*resourceTypes.NetworkResource{{ID: "r"}}}.isEmpty())
assert.False(t, Change{Networks: []*networkTypes.Network{{ID: "n"}}}.isEmpty())
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
}
func TestCollectPolicyDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
}, {
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
peerSet := map[string]struct{}{}
collectPolicyDirectPeers(policy, peerSet)
assert.Contains(t, peerSet, "p1")
assert.Contains(t, peerSet, "p2")
assert.NotContains(t, peerSet, "r1")
}
func TestCollectPolicySources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
Destinations: []string{"g2"},
}}}
groupSet := map[string]struct{}{}
peerSet := map[string]struct{}{}
collectPolicySources(policy, groupSet, peerSet)
assert.Contains(t, groupSet, "g1")
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
assert.Contains(t, peerSet, "p1")
}

View File

@@ -8,6 +8,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
@@ -47,8 +48,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
@@ -63,11 +65,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
@@ -75,6 +72,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(addedGroups, removedGroups)}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -85,9 +87,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -133,20 +133,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {

View File

@@ -298,11 +298,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return nil, err
}
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
savedPeer1, _, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
if err != nil {
return nil, err
}

View File

@@ -11,6 +11,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -79,7 +80,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -91,11 +93,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -106,6 +103,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
}
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
if err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -116,9 +118,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -134,7 +134,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -153,20 +154,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
for _, peerID := range peersToAdd {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
for _, peerID := range peersToRemove {
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
if err = syncGroupMembership(ctx, transaction, accountID, newGroup.ID, peersToAdd, peersToRemove); err != nil {
return err
}
@@ -178,6 +166,17 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
// A membership change does not alter which entities reference the group, so
// the dependency walk runs once against the post-change snapshot. The new
// members are already in the snapshot's index; the removed members are
// carried separately and folded in only when the group is linked.
if len(peersToRemove) > 0 {
change.RemovedPeersByGroup = map[string][]string{newGroup.ID: peersToRemove}
}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -188,13 +187,26 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
// syncGroupMembership applies the peer membership delta for a group within a transaction.
func syncGroupMembership(ctx context.Context, transaction store.Store, accountID, groupID string, peersToAdd, peersToRemove []string) error {
for _, peerID := range peersToAdd {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, groupID, err)
}
}
for _, peerID := range peersToRemove {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, groupID, err)
}
}
return nil
}
// CreateGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
@@ -209,11 +221,14 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var snaps []*affectedpeers.Snapshot
var changes []affectedpeers.Change
var globalErr error
groupIDs := make([]string, 0, len(groups))
createdCount := 0
for _, newGroup := range groups {
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
var snap *affectedpeers.Snapshot
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
@@ -230,35 +245,31 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
return err
}
groupIDs = append(groupIDs, newGroup.ID)
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
return nil
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
return err
})
if err != nil {
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groupIDs) == 1 {
if createdCount == 0 {
return err
}
globalErr = errors.Join(globalErr, err)
// continue updating other groups
continue
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
createdCount++
snaps = append(snaps, snap)
changes = append(changes, change)
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
}
go am.dispatchAffected(ctx, accountID, snaps, changes)
return globalErr
}
@@ -277,12 +288,13 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var snaps []*affectedpeers.Snapshot
var changes []affectedpeers.Change
var globalErr error
groupIDs := make([]string, 0, len(groups))
for _, newGroup := range groups {
events, err := am.updateSingleGroup(ctx, accountID, userID, newGroup)
change := affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}}
events, snap, err := am.updateSingleGroup(ctx, accountID, userID, newGroup, change)
if err != nil {
log.WithContext(ctx).Errorf("failed to update group %s: %v", newGroup.ID, err)
if len(groups) == 1 {
@@ -292,27 +304,22 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
continue
}
eventsToStore = append(eventsToStore, events...)
groupIDs = append(groupIDs, newGroup.ID)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
snaps = append(snaps, snap)
changes = append(changes, change)
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
go am.dispatchAffected(ctx, accountID, snaps, changes)
return globalErr
}
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) ([]func(), error) {
func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountID, userID string, newGroup *types.Group, change affectedpeers.Change) ([]func(), *affectedpeers.Snapshot, error) {
var events []func()
var snap *affectedpeers.Snapshot
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
return err
@@ -333,9 +340,12 @@ func (am *DefaultAccountManager) updateSingleGroup(ctx context.Context, accountI
}
events = am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
return nil
var err error
snap, err = affectedpeers.Load(ctx, transaction, accountID, change)
return err
})
return events, err
return events, snap, err
}
// prepareGroupEvents prepares a list of event functions to be stored.
@@ -438,6 +448,8 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*types.Group
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
@@ -445,26 +457,23 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
deletedGroups, allErrors = collectDeletableGroups(ctx, transaction, accountID, userID, groupIDs, extraSettings.FlowGroups)
for _, group := range deletedGroups {
groupIDsToDelete = append(groupIDsToDelete, group.ID)
}
if len(groupIDsToDelete) == 0 {
return allErrors
}
// Delete: compute affected peers from the PRE-delete state. The groups,
// their members and the entities referencing them still exist, so a plain
// Load+Expand captures everyone — no removed-peer folding needed.
change = affectedpeers.Change{ChangedGroupIDs: groupIDsToDelete}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
return err
}
@@ -483,25 +492,47 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return allErrors
}
// collectDeletableGroups loads and validates each group for deletion, returning
// the groups that may be deleted and the joined validation errors for the rest.
func collectDeletableGroups(ctx context.Context, transaction store.Store, accountID, userID string, groupIDs, flowGroups []string) ([]*types.Group, error) {
var deletable []*types.Group
var allErrors error
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err = validateDeleteGroup(ctx, transaction, group, userID, flowGroups); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
deletable = append(deletable, group)
}
return deletable, allErrors
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var err error
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
return err
}
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
var err error
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -511,9 +542,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -521,8 +550,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
var err error
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
@@ -534,12 +564,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -549,29 +578,32 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var err error
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{
ChangedGroupIDs: []string{groupID},
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return err
}
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
if err := am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
return err
}
if err = am.reconcileIPv6ForGroupChanges(ctx, transaction, accountID, []string{groupID}); err != nil {
// The removed peer is carried in change.RemovedPeersByGroup and folded in
// only when the group is linked, so loading post-removal is correct.
var err error
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -581,9 +613,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -591,8 +621,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
var err error
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID)
@@ -604,8 +635,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
// Load before persisting the removal, so the snapshot still maps the group
// to the resource and the bridge can reach its routing peers.
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -619,9 +651,7 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -832,49 +862,103 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store,
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
// It fetches each collection once and checks all groupIDs against them in memory.
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
groupSet := make(map[string]struct{}, len(groupIDs))
for _, id := range groupIDs {
groupSet[id] = struct{}{}
}
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
return false, nil
}
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, group := range groups {
if group.HasPeers() || group.HasResources() {
for _, ns := range nameServerGroups {
if anyInSet(ns.Groups, groupSet) {
return true, nil
}
}
return false, nil
}
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true, nil
}
}
}
return false, nil
}
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, r := range routes {
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
return true, nil
}
}
return false, nil
}
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, router := range routers {
if anyInSet(router.PeerGroups, groupSet) {
return true, nil
}
}
return false, nil
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}

View File

@@ -55,7 +55,7 @@ func TestGroupIPv6Assignment(t *testing.T) {
key, err := wgtypes.GeneratePrivateKey()
require.NoError(t, err)
peer, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := am.AddPeer(ctx, "", setupKey.Key, "", &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "ipv6-test-host"},
}, false)

View File

@@ -479,7 +479,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
return
}
peer, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
peer, _, _, _, err := h.accountManager.AddPeer(r.Context(), userAuth.AccountId, "", userAuth.UserId, newPeer, true)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -728,7 +728,7 @@ func Test_LoginPerformance(t *testing.T) {
}
login := func() error {
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return err
@@ -746,7 +746,7 @@ func Test_LoginPerformance(t *testing.T) {
go func(peerLogin types.PeerLogin, counterStart *int32) {
defer wgPeer.Done()
_, _, _, err = am.LoginPeer(context.Background(), peerLogin)
_, _, _, _, err = am.LoginPeer(context.Background(), peerLogin)
if err != nil {
t.Logf("failed to login peer: %v", err)
return

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
@@ -38,13 +39,13 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error)
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
@@ -97,7 +98,7 @@ type MockAccountManager struct {
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
@@ -132,6 +133,7 @@ type MockAccountManager struct {
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
ExpandAndUpdateAffectedFunc func(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
@@ -209,6 +211,12 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
}
}
func (am *MockAccountManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
if am.ExpandAndUpdateAffectedFunc != nil {
am.ExpandAndUpdateAffectedFunc(ctx, accountID, snap, change)
}
}
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)
@@ -337,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -416,11 +424,11 @@ func (am *MockAccountManager) AddPeer(
userId string,
peer *nbpeer.Peer,
temporary bool,
) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if am.AddPeerFunc != nil {
return am.AddPeerFunc(ctx, accountID, setupKey, userId, peer, temporary)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
@@ -854,11 +862,11 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account
}
// LoginPeer mocks LoginPeer of the AccountManager interface
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if am.LoginPeerFunc != nil {
return am.LoginPeerFunc(ctx, login)
}
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
return nil, nil, nil, false, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
}
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"unicode/utf8"
@@ -11,6 +12,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
@@ -57,19 +59,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{DistributionGroupIDs: newNSGroup.Groups}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -81,9 +83,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return newNSGroup.Copy(), nil
}
@@ -102,7 +102,8 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
@@ -115,12 +116,12 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -132,9 +133,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -150,7 +149,8 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
@@ -158,8 +158,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
// Load before delete: the post-delete state no longer references the groups.
change = affectedpeers.Change{DistributionGroupIDs: nsGroup.Groups}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
@@ -175,9 +176,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -224,24 +223,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+

View File

@@ -896,11 +896,11 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account,
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", userID, peer2, false)
if err != nil {
return nil, err
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/types"
@@ -15,7 +16,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -127,30 +127,39 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
}
var eventsToStore []func()
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{Networks: []*types.Network{network}}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get resources in network: %w", err)
}
for _, resource := range resources {
event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
}
eventsToStore = append(eventsToStore, event...)
}
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get routers in network: %w", err)
}
for _, router := range routers {
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
var lerr error
if snap, lerr = affectedpeers.Load(ctx, transaction, accountID, change); lerr != nil {
return lerr
}
for _, resource := range resources {
deleted, event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
}
change.Resources = append(change.Resources, deleted)
eventsToStore = append(eventsToStore, event...)
}
for _, router := range netRouters {
deleted, event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
if err != nil {
return fmt.Errorf("failed to delete router: %w", err)
}
change.Routers = append(change.Routers, deleted)
eventsToStore = append(eventsToStore, event)
}
@@ -178,7 +187,7 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
m.accountManager.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -29,7 +30,7 @@ type Manager interface {
GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error)
UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error)
DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error
DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error)
DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) (*types.NetworkResource, []func(), error)
}
type managerImpl struct {
@@ -114,45 +115,12 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
}
var eventsToStore []func()
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{Resources: []*types.NetworkResource{resource}}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
}
eventsToStore = append(eventsToStore, event)
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
var txErr error
eventsToStore, snap, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource, change)
return txErr
})
if err != nil {
return nil, fmt.Errorf("failed to create network resource: %w", err)
@@ -162,11 +130,55 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
m.accountManager.ExpandAndUpdateAffected(ctx, resource.AccountID, snap, change)
return resource, nil
}
func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource, change affectedpeers.Change) ([]func(), *affectedpeers.Snapshot, error) {
_, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
}
var eventsToStore []func()
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
})
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return nil, nil, fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
snap, err := affectedpeers.Load(ctx, transaction, resource.AccountID, change)
if err != nil {
return nil, nil, err
}
return eventsToStore, snap, nil
}
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
@@ -207,6 +219,8 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
resource.Prefix = prefix
var eventsToStore []func()
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
@@ -232,6 +246,14 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network resource: %w", err)
}
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get old resource groups: %w", err)
}
for _, g := range oldGroups {
oldResource.GroupIDs = append(oldResource.GroupIDs, g.ID)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
@@ -247,6 +269,11 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
})
change = affectedpeers.Change{Resources: []*types.NetworkResource{oldResource, resource}}
if snap, err = affectedpeers.Load(ctx, transaction, resource.AccountID, change); err != nil {
return err
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
@@ -270,7 +297,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
}
}()
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
m.accountManager.ExpandAndUpdateAffected(ctx, resource.AccountID, snap, change)
return resource, nil
}
@@ -331,8 +358,26 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
}
var events []func()
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
existing, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to get network resource: %w", err)
}
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, accountID, resourceID)
if err != nil {
return fmt.Errorf("failed to get resource groups: %w", err)
}
for _, g := range oldGroups {
existing.GroupIDs = append(existing.GroupIDs, g.ID)
}
change = affectedpeers.Change{Resources: []*types.NetworkResource{existing}}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
_, events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
}
@@ -352,51 +397,53 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
m.accountManager.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) (*types.NetworkResource, []func(), error) {
resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err)
return nil, nil, fmt.Errorf("failed to get network resource: %w", err)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
if resource.NetworkID != networkID {
return nil, errors.New("resource not part of network")
return nil, nil, errors.New("resource not part of network")
}
groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to get resource groups: %w", err)
return nil, nil, fmt.Errorf("failed to get resource groups: %w", err)
}
var eventsToStore []func()
for _, group := range groups {
resource.GroupIDs = append(resource.GroupIDs, group.ID)
event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, userID, group.ID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to remove resource from group: %w", err)
return nil, nil, fmt.Errorf("failed to remove resource from group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
err = transaction.DeleteNetworkResource(ctx, accountID, resourceID)
if err != nil {
return nil, fmt.Errorf("failed to delete network resource: %w", err)
return nil, nil, fmt.Errorf("failed to delete network resource: %w", err)
}
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network))
})
return eventsToStore, nil
return resource, eventsToStore, nil
}
func NewManagerMock() Manager {
@@ -431,6 +478,6 @@ func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, net
return nil
}
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
return []func(){}, nil
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) (*types.NetworkResource, []func(), error) {
return nil, []func(){}, nil
}

View File

@@ -9,13 +9,13 @@ import (
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -26,7 +26,7 @@ type Manager interface {
GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error)
UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error)
DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error
DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error)
DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (*types.NetworkRouter, func(), error)
}
type managerImpl struct {
@@ -90,6 +90,8 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{Routers: []*types.NetworkRouter{router}}
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
@@ -112,6 +114,10 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
return fmt.Errorf("failed to increment network serial: %w", err)
}
if snap, err = affectedpeers.Load(ctx, transaction, router.AccountID, change); err != nil {
return err
}
return nil
})
if err != nil {
@@ -120,7 +126,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
m.accountManager.ExpandAndUpdateAffected(ctx, router.AccountID, snap, change)
return router, nil
}
@@ -156,36 +162,12 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
err = transaction.UpdateNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
var txErr error
network, snap, change, txErr = m.updateRouterInTransaction(ctx, transaction, router)
return txErr
})
if err != nil {
return nil, err
@@ -193,11 +175,47 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
m.accountManager.ExpandAndUpdateAffected(ctx, router.AccountID, snap, change)
return router, nil
}
func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, *affectedpeers.Snapshot, affectedpeers.Change, error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return nil, nil, affectedpeers.Change{}, status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
return nil, nil, affectedpeers.Change{}, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to update network router: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil {
return nil, nil, affectedpeers.Change{}, fmt.Errorf("failed to increment network serial: %w", err)
}
change := affectedpeers.Change{Routers: []*types.NetworkRouter{existing, router}}
snap, err := affectedpeers.Load(ctx, transaction, router.AccountID, change)
if err != nil {
return nil, nil, affectedpeers.Change{}, err
}
return network, snap, change, nil
}
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
@@ -208,8 +226,19 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
}
var event func()
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID)
if err != nil {
return fmt.Errorf("failed to get network router: %w", err)
}
change = affectedpeers.Change{Routers: []*types.NetworkRouter{existing}}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
_, event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
if err != nil {
return fmt.Errorf("failed to delete network router: %w", err)
}
@@ -227,36 +256,36 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
m.accountManager.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (*types.NetworkRouter, func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err)
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID)
if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err)
return nil, nil, fmt.Errorf("failed to get network router: %w", err)
}
if router.NetworkID != networkID {
return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
return nil, nil, status.NewRouterNotPartOfNetworkError(routerID, networkID)
}
err = transaction.DeleteNetworkRouter(ctx, accountID, routerID)
if err != nil {
return nil, fmt.Errorf("failed to delete network router: %w", err)
return nil, nil, fmt.Errorf("failed to delete network router: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, userID, routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network))
}
return event, nil
return router, event, nil
}
func NewManagerMock() Manager {
@@ -287,6 +316,9 @@ func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, netwo
return nil
}
func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
return func() {}, nil
func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (*types.NetworkRouter, func(), error) {
return nil, func() {
// no-op mock: returns zero values so tests that don't exercise router deletion
// can satisfy the Manager interface without a real store.
}, nil
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/status"
@@ -73,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
@@ -105,35 +106,22 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
expired := peer.Status != nil && peer.Status.LoginExpired
if peer.AddedWithSSOLogin() {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.schedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
return err
}
if expired {
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
// A login-expired peer reconnecting, or an embedded proxy peer flipping to
// connected (which triggers SynthesizePrivateServiceZones), must refresh the
// peers reachable from it. The embedded-proxy fan-out tolerates a dispatch error.
if peer.Status != nil && peer.Status.LoginExpired {
affectedPeerIDs := am.markConnectedAffectedPeers(ctx, accountID, peer.ID, nmap)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}, affectedPeerIDs); err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
// An embedded proxy peer flipping to connected is the trigger for
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
// tunnel IP. Without an account-wide netmap recompute, user peers keep
// the stale synth (or no synth at all on first connect) until some
// other change pokes the controller. Fire OnPeersUpdated so the
// buffered recompute fans the new state out to every peer.
if peer.ProxyMeta.Embedded {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
affectedPeerIDs := am.markConnectedAffectedPeers(ctx, accountID, peer.ID, nmap)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}, affectedPeerIDs); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
}
}
@@ -141,6 +129,25 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
// schedulePeerExpirations reschedules the account's login/inactivity expiration
// timers for an SSO peer that just connected.
func (am *DefaultAccountManager) schedulePeerExpirations(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
if !peer.AddedWithSSOLogin() {
return nil
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.schedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
return nil
}
// MarkPeerDisconnected marks a peer as disconnected, but only when the
// stored session token matches the one passed in. A mismatch means a
// newer stream has already taken ownership of the peer — disconnects from
@@ -175,11 +182,12 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
// offline, drive an account-wide netmap recompute so the synthesized
// DNS records that pointed at it are pulled. Without this the records
// linger client-side at TTL until something else triggers a refresh.
// offline, refresh the peers that had synthesized records pointing at
// it so they pull the stale entries instead of waiting out TTL.
if peer.ProxyMeta.Embedded {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
}
}
@@ -346,7 +354,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
affectedPeerIDs = append(affectedPeerIDs, peer.ID)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -501,10 +512,6 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPeerNotPartOfAccountError()
}
var peer *nbpeer.Peer
var settings *types.Settings
var eventsToStore []func()
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
return fmt.Errorf("failed to check if resource is used by service: %w", err)
@@ -513,8 +520,38 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPeerInUseError(peerID, serviceID)
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
change := affectedpeers.Change{ChangedPeerIDs: []string{peerID}}
settings, eventsToStore, snap, err := am.deletePeerInTransaction(ctx, accountID, userID, peerID, change)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if err = am.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
}
affectedPeerIDs := snap.Expand(ctx, accountID, change)
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
return nil
}
// deletePeerInTransaction loads the peer + settings, captures the affected-peers
// snapshot (before the delete, while the peer's group memberships still exist),
// then deletes the peer and bumps the network serial — all in one transaction.
func (am *DefaultAccountManager) deletePeerInTransaction(ctx context.Context, accountID, userID, peerID string, change affectedpeers.Change) (*types.Settings, []func(), *affectedpeers.Snapshot, error) {
var settings *types.Settings
var eventsToStore []func()
var snap *affectedpeers.Snapshot
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
@@ -528,8 +565,11 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
if err != nil {
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
if eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings); err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
}
@@ -539,23 +579,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return nil
})
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if err = am.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
}
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
return nil
return settings, eventsToStore, snap, err
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)
@@ -694,10 +718,10 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
// to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied
// Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused).
// The peer property is just a placeholder for the Peer properties to pass further
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
// no auth method provided => reject access
return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
}
upperKey := strings.ToUpper(setupKey)
@@ -713,7 +737,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
// The connecting peer should be able to recover with a retry.
_, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
return nil, nil, nil, false, status.Errorf(status.PreconditionFailed, "peer has been already registered")
}
opEvent := &activity.Event{
@@ -724,7 +748,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
peerAddConfig, err := am.processPeerAddAuth(ctx, accountID, userID, encodedHashedKey, peer, temporary, addedByUser, addedBySetupKey, opEvent)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
accountID = peerAddConfig.AccountID
ephemeral := peerAddConfig.Ephemeral
@@ -739,7 +763,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
if err := domain.ValidateDomainsList(peer.ExtraDNSLabels); err != nil {
return nil, nil, nil, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
return nil, nil, nil, false, status.Errorf(status.InvalidArgument, "invalid extra DNS labels: %v", err)
}
registrationTime := time.Now().UTC()
@@ -765,7 +789,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get account settings: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get account settings: %w", err)
}
if am.geo != nil && newPeer.Location.ConnectionIP != nil {
@@ -783,30 +807,30 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed getting network: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed getting network: %w", err)
}
maxAttempts := 10
for attempt := 1; attempt <= maxAttempts; attempt++ {
netPrefix, err := netip.ParsePrefix(network.Net.String())
if err != nil {
return nil, nil, nil, fmt.Errorf("parse network prefix: %w", err)
return nil, nil, nil, false, fmt.Errorf("parse network prefix: %w", err)
}
freeIP, err := types.AllocateRandomPeerIP(netPrefix)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free IP: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get free IP: %w", err)
}
var freeLabel string
if ephemeral || attempt > 1 {
freeLabel, err = getPeerIPDNSLabel(freeIP, peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
}
} else {
freeLabel, err = nbdns.GetParsedDomainLabel(peer.Meta.Hostname)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get free DNS label: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to get free DNS label: %w", err)
}
}
newPeer.DNSLabel = freeLabel
@@ -828,11 +852,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
if allocate {
v6Prefix, err := netip.ParsePrefix(network.NetV6.String())
if err != nil {
return nil, nil, nil, fmt.Errorf("parse IPv6 prefix: %w", err)
return nil, nil, nil, false, fmt.Errorf("parse IPv6 prefix: %w", err)
}
freeIPv6, err := types.AllocateRandomPeerIPv6(v6Prefix)
if err != nil {
return nil, nil, nil, fmt.Errorf("allocate peer IPv6: %w", err)
return nil, nil, nil, false, fmt.Errorf("allocate peer IPv6: %w", err)
}
newPeer.IPv6 = freeIPv6
}
@@ -905,10 +929,10 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
continue
}
return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err)
return nil, nil, nil, false, fmt.Errorf("failed to add peer to database: %w", err)
}
if newPeer == nil {
return nil, nil, nil, fmt.Errorf("new peer is nil")
return nil, nil, nil, false, fmt.Errorf("new peer is nil")
}
opEvent.TargetID = newPeer.ID
@@ -916,7 +940,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
if !addedByUser {
opEvent.Meta["setup_key_name"] = peerAddConfig.SetupKeyName
}
if newPeer.Status != nil && newPeer.Status.RequiresApproval {
requiresApproval := newPeer.Status != nil && newPeer.Status.RequiresApproval
if requiresApproval {
opEvent.Meta["pending_approval"] = true
}
@@ -924,12 +949,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, newPeer, !requiresApproval)
if err != nil {
return nil, nil, nil, false, err
}
changedPeerIDs := []string{newPeer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
return p, nmap, pc, err
return newPeer, network, postureChecks, enableSSH, nil
}
func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
@@ -1011,17 +1042,51 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
nmap, resPostureChecks, dnsFwdPort, err := am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || updated || (len(postureChecks) > 0 || versionChanged) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
return am.networkMapController.GetValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
return peer, nmap, resPostureChecks, dnsFwdPort, nil
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
// peer's own validated network map is bidirectional for policy and routing
// reachability, so when the peer stays valid and no source-posture gate is in
// play it already lists every affected peer — reuse it and skip the full
// dependency walk. Posture checks gate the source side of a policy only, so a
// metadata change that flips a posture result removes this peer from others'
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
// back to the resolver.
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
if peerNotValid || (metaUpdated && hasPostureChecks) {
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
}
return affectedPeerIDsFromNetworkMap(nmap, peerID)
}
// markConnectedAffectedPeers resolves the peers affected when a peer connects
// (login-expiry reconnect or embedded-proxy connect). The connecting peer's
// network map already lists them bidirectionally — the synthesized
// private-service policy puts proxy access-group members in the proxy peer's own
// map, and these edges carry no source-posture gate. An invalid peer has an
// empty map, so fall back to the resolver in that case.
func (am *DefaultAccountManager) markConnectedAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap) []string {
if nmap == nil || len(nmap.Peers)+len(nmap.OfflinePeers) == 0 {
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
}
return affectedPeerIDsFromNetworkMap(nmap, peerID)
}
func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound {
// we couldn't find this peer by its public key which can mean that peer hasn't been registered yet.
// Try registering it.
@@ -1037,12 +1102,12 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo
}
log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)
return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer")
return nil, nil, nil, false, status.Errorf(status.Internal, "failed while logging in peer")
}
// LoginPeer logs in or registers a peer.
// If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so.
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey)
if err != nil {
return am.handlePeerLoginNotFound(ctx, login, err)
@@ -1054,20 +1119,17 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if login.UserID == "" {
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
}
var peer *nbpeer.Peer
var updateRemotePeers bool
var isPeerUpdated bool
var ipv6CapabilityChanged bool
var postureChecks []*posture.Checks
var shouldStorePeer bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -1076,9 +1138,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return err
}
// this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false
if login.UserID != "" {
if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
@@ -1092,7 +1151,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
updateRemotePeers = true
}
}
@@ -1101,23 +1159,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
isPeerUpdated, _ = peer.UpdateMetaIfNew(login.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if isPeerUpdated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
shouldStorePeer = true
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
}
if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey
shouldStorePeer = true
updateRemotePeers = true
}
if !peer.AllowExtraDNSLabels && len(login.ExtraDNSLabels) > 0 {
@@ -1133,23 +1177,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil
})
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, false, err
}
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
network, postureChecks, enableSSH, err := getPeerLoginInfo(ctx, am.Store, accountID, peer, !isRequiresApproval)
if err != nil {
return nil, nil, nil, false, err
}
if isStatusChanged || shouldStorePeer {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, false, fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
return p, nmap, pc, err
return peer, network, postureChecks, enableSSH, nil
}
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
@@ -1225,6 +1274,50 @@ func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubK
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
}
// getPeerLoginInfo computes the login/register response data (network, posture
// checks, SSH) from the store without building the peer's full network map.
func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer, isValid bool) (*types.Network, []*posture.Checks, bool, error) {
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, nil, false, fmt.Errorf("get account network: %w", err)
}
if !isValid {
return network, nil, false, nil
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
if err != nil {
return nil, nil, false, err
}
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
@@ -1407,6 +1500,100 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
}
// ExpandAndUpdateAffected expands a Snapshot (loaded INSIDE the now-committed
// transaction) into the affected peers and dispatches the network-map refresh.
// Pure in-memory work plus dispatch, so it runs AFTER commit — the fan-out walk
// never holds the write lock, over the consistent in-tx snapshot. Exported so the
// networks sub-package managers (which hold only account.Manager) share it.
func (am *DefaultAccountManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
go am.dispatchAffected(ctx, accountID, []*affectedpeers.Snapshot{snap}, []affectedpeers.Change{change})
}
// dispatchAffected expands one or more (snapshot, change) pairs — collected across
// one or several transactions — unions their affected peers, and dispatches a
// single network-map refresh. Each snapshot must already be loaded inside its
// transaction; this runs AFTER commit (pure in-memory + dispatch). It is spawned
// in a goroutine that outlives the request, so it detaches from the request
// context's cancellation up front.
func (am *DefaultAccountManager) dispatchAffected(ctx context.Context, accountID string, snaps []*affectedpeers.Snapshot, changes []affectedpeers.Change) {
ctx = context.WithoutCancel(ctx)
var lists [][]string
for i, snap := range snaps {
if snap == nil {
continue
}
lists = append(lists, snap.Expand(ctx, accountID, changes[i]))
}
affectedPeerIDs := unionStrings(lists...)
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for account %s", accountID)
return
}
log.WithContext(ctx).Debugf("updating %d affected peers for account %s: %v", len(affectedPeerIDs), accountID, affectedPeerIDs)
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
}
// unionStrings concatenates the given string lists into one deduplicated slice,
// preserving first-occurrence order.
func unionStrings(lists ...[]string) []string {
seen := make(map[string]struct{})
var out []string
for _, list := range lists {
for _, id := range list {
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
}
return out
}
// affectedPeerIDsFromNetworkMap returns the peer IDs referenced by a peer's
// network map (its connected and offline peers, which include routing and proxy
// peers), excluding the peer itself. For a freshly added peer these are, by ACL
// symmetry, exactly the peers its addition affects.
func affectedPeerIDsFromNetworkMap(nmap *types.NetworkMap, selfPeerID string) []string {
if nmap == nil {
return nil
}
seen := make(map[string]struct{}, len(nmap.Peers)+len(nmap.OfflinePeers))
ids := make([]string, 0, len(nmap.Peers)+len(nmap.OfflinePeers))
add := func(peers []*nbpeer.Peer) {
for _, p := range peers {
if p == nil || p.ID == "" || p.ID == selfPeerID {
continue
}
if _, ok := seen[p.ID]; ok {
continue
}
seen[p.ID] = struct{}{}
ids = append(ids, p.ID)
}
}
add(nmap.Peers)
add(nmap.OfflinePeers)
return ids
}
// resolveAffectedPeersForPeerChanges loads a snapshot and expands it for a peer
// change. The graph is unchanged by these paths, so it runs out of the mutating
// transaction (after commit); the resolver derives the peers' group memberships
// during the walk, so the caller passes only the changed peer IDs.
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
snap, err := affectedpeers.Load(ctx, s, accountID, change)
if err != nil {
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers: %v", err)
return nil
}
return snap.Expand(ctx, accountID, change)
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
}

Some files were not shown because too many files have changed in this diff Show More