Compare commits

...

21 Commits

Author SHA1 Message Date
mlsmaycon
2d8b0310a4 [client, proxy] IPv6 in-place apply + accept-loop hardening on netstack listeners
Two related fixes for the embedded netbird client and the per-account
inbound listeners that ride on its gVisor netstack.

client/internal/engine.go — replace hasIPv6Changed with reconcileIPv6:

  - First v6 assignment (current had no v6, conf carries one) is applied
    in place via WGIface.UpdateAddr instead of returning ErrResetConnection.
    Pre-fix, every embedded client whose account had IPv6 enabled would
    reset on its first NetworkMap sync — boot config has no v6, the sync
    introduces one, the engine tore itself down to "apply" it. That
    teardown destroys the gVisor netstack and orphans every listener
    bound on it, which is what made the proxy's per-account :80/:443
    silently stop accepting traffic.
  - v6 removed clears in place.
  - v6 swapped to a different non-empty value still resets (gVisor
    netstack can't safely swap its address at runtime).
  - Mutates e.config.WgAddr to match the applied state so subsequent
    PeerConfig comparisons are stable.

proxy/internal/tcp/accept.go (new) + proxy/inbound.go +
proxy/internal/tcp/router.go — harden the two Accept() loops on
netstack-backed listeners:

  - IsClosedListenerErr recognises net.ErrClosed AND gVisor's
    "endpoint is in invalid state" — the latter survives gonet's
    *net.OpError wrapping in a way errors.Is(.., net.ErrClosed) does
    not. Without this the loop spins CPU-hot after the underlying
    netstack is destroyed (peer rekey, embedded-client reset, account
    churn), emitting one log line per iteration.
  - AcceptBackoff implements the exponential backoff that
    net/http.Server.Serve uses on transient Accept errors: 5ms doubling
    up to 1s. Defence-in-depth so an unknown sticky error cannot burn
    a CPU core even if IsClosedListenerErr misses its signature.

proxy/internal/roundtrip/netbird.go — emit a single structured INFO
line summarising every embed.Options flag (account_id, service_id,
public_key, management_url, wg_port, block_inbound, block_lan_access,
disable_ipv6, no_userspace, presence of credentials) when each
per-account embedded client is created. Secrets reduced to a "present"
boolean — never logged verbatim. Diagnostic-only; no behavior change,
but it makes the "why is this embedded peer misbehaving" loop a single
log read instead of a code dive.

Tests (real listeners, scripted errors, no mocks of production code):
  - engine_reconcileipv6_test.go: 8 cases for every transition (first
    assignment, no change, removed, prefix-length changed, value
    changed, invalid bytes, UpdateAddr error) plus a updateConfig
    integration check that the fix actually fires on a v6-added
    PeerConfig.
  - accept_test.go: IsClosedListenerErr matrix + AcceptBackoff
    progression / cap / reset / cancel-during-wait / cancel-before-call.
  - router_test.go, inbound_test.go: scriptedAcceptListener +
    TestRouter_Serve_ExitsOnGVisorInvalidEndpoint and
    TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint —
    regression guards that fail in 2 s if the loop ever spins.
2026-06-18 10:37:51 +02:00
Theodor Midtlien
ee360963f9 [client] Migrate profile identity from display name to ID and allow renaming of profiles (#6367)
* Migrate to profile ids

* Migrate android profile manager

* Clean up

* Fix review

* Add ID type

* Fix test and runes in ShortID()

* Fix profile switch on up and android comments

* Revert android profile to string id

* Fix feedback

* Fix UI feedback

* Fix id assignment

* Add renaming of profiles

* Fix review

* Remove ui binary
* Fix getProfileConfigPath not validating id

* Change resolve handle order and fix server merge problems

* Fix mdm test
2026-06-18 08:49:19 +02:00
Maycon Santos
8d9580e491 [misc] improve goreleaser with RC handling and update docker builds (#6438)
- introduce variables to avoid publishing latest docker tags and installers
- Refactor .goreleaser.yaml to simplify docker configurations and add environment-driven flags
- removed management debug containers (it was doing only log var)
- Stopped building arm v6 32bits in favor of v7 32 bits for services (not client)
- Add target argument to docker files
2026-06-17 20:13:13 +02:00
Viktor Liu
5bd7c6c7ea [client] Detect and recover from a stalled signal receive stream (#6459) 2026-06-17 18:48:09 +02:00
Zoltan Papp
8ae2cd0a08 [client] Fix ios route notify ordering (#6454)
* [client] fix iOS route-update reordering that black-holed IPv6 on exit-node disable

On iOS the route notifier delivered each prefix update from its own
fire-and-forget goroutine (notify -> `go func`), so Go provided no ordering
guarantee between consecutive updates. It also read currentPrefixes inside
that goroutine without holding the lock, racing the next OnNewPrefixes write.

On exit-node disable the core removes the default routes as two separate
prefix updates (0.0.0.0/0, then the synthesized ::/0). When the two
goroutines were reordered, the stale snapshot still containing ::/0 was
delivered last and clobbered the correct default-free one. iOS then kept the
::/0 default route on the tunnel with no exit node to carry it, black-holing
all IPv6 traffic while IPv4 recovered correctly.

Fix: deliver updates through a single worker goroutine fed by a buffered
channel, preserving production order, and snapshot the joined prefix string
under the mutex so it can't race a concurrent update. Buffered so producers
(which run under the route manager lock) don't block on the listener callback.

* [client] close iOS notifier delivery goroutine on Stop, unbounded queue

The delivery goroutine was never stopped, leaking on every engine
restart. Add Notifier.Close, called from the route manager Stop after
routing cleanup.

Replace the buffered update channel with a cond-driven linked-list
queue so route-update producers (running under the route manager lock)
never block when the listener callback is slow.
2026-06-17 18:29:33 +02:00
Pascal Fischer
e4397d4d46 [management] remove nmap calc from login (#6449) 2026-06-17 16:37:24 +02:00
Viktor Liu
6fbc90b4d3 [client, relay] Expose relay transport and connection errors in status and metrics (#6342) 2026-06-17 15:41:48 +02:00
Riccardo Manfrin
5095e17cc5 [management] fix flaky Test_SaveAccount_Large from random IP collision (#6452) 2026-06-17 14:00:50 +02:00
Zoltan Papp
6df0175607 [client] Add IsLoginRequiredCached for iOS mobile client (#6447)
Expose a network-free login-required check backed by the in-memory status
recorder. Unlike IsLoginRequired(), which creates a fresh auth client and
performs a blocking network call, IsLoginRequiredCached() reports whether the
LAST observed management error was an auth failure (PermissionDenied/
InvalidArgument).

This lets the iOS connection listener detect a mid-session token expiry from
within onDisconnected during teardown without blocking on a slow or
unavailable network.
2026-06-16 16:15:19 +02:00
Zoltan Papp
3c23700e56 [client] Add iOS debug bundle support in Go (#6270)
* Add iOS debug bundle support in Go

Thread cacheDir through NewClient -> RunOniOS -> MobileDependency.TempDir
so the iOS client can pass its sandbox-writable cache directory for
debug bundle zip file creation instead of os.TempDir().

Move log collection into platform-dispatched addPlatformLog():
- iOS: adds the file-based Go client log (with rotation, stderr/stdout
  companions and anonymization handled by addLogfile) plus the Swift app
  log (swift-log.log) written by the iOS app into the same log directory
- Other non-Android platforms: existing file-based log + systemd fallback

Narrow the debug_nonandroid.go build tag to !android && !ios so iOS no
longer attempts the systemd journal fallback.

Add a DebugBundle() entry point to the iOS Go client that generates a
bundle, uploads it and returns the upload key. It works with or without
a running engine: when the engine is up it reuses the live config, sync
response and client metrics; otherwise it loads the config from disk (or
the preloaded tvOS config). Guard the live config/ConnectClient behind a
state mutex since DebugBundle may run on a different thread.

* Include the iOS state file in the debug bundle

addStateFile() resolved the state path via ServiceManager.GetStatePath(),
which on iOS points at a hard-coded default that does not exist in the app
sandbox, so the state file was silently skipped.

Add an optional StatePath to GeneratorDependencies and use it when set,
falling back to the ServiceManager default otherwise. The iOS DebugBundle
passes the client's actual state file path (the App Group profile state),
matching the Android bundle which includes the state file.

* ios: enable sync response persistence for debug bundle

Turn on sync response persistence before starting the engine so
DebugBundle can include the network map. On iOS the store is disk-backed
(see syncstore) to keep the map out of the constrained process memory.

* ios: pass log file path through NewClient constructor (#6393)

Add logFilePath field to Client struct and expose it as a parameter
in NewClient so callers provide the Go log path at construction time.
Wire it into DebugBundle via GeneratorDependencies.LogPath so the
debug bundle includes client.log and swift-log.log regardless of
whether the bundle is triggered by the app or the management server.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>

* ios: pass log file path to engine for remote debug bundles

RunOniOS started the engine with an empty LogPath, so EngineConfig.LogPath
was never set. Management-triggered (jobs) debug bundles read the log path
from the engine config, so they collected no client logs (client.log,
rotated logs, swift-log.log). The GUI path was unaffected because it passes
c.logFilePath directly to the bundle generator.

Thread c.logFilePath through RunOniOS into the engine config so remote
bundles include the client logs too.

---------

Co-authored-by: evgeniyChepelev <68751844+evgeniyChepelev@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-06-16 15:54:46 +02:00
Pascal Fischer
38ad2b67e8 [proxy] fix context for udprelay (#6444) 2026-06-16 14:41:17 +02:00
Pascal Fischer
01aa49433e [management] delete targets when deleting exposed service (#6442) 2026-06-16 14:33:24 +02:00
Zoltan Papp
08a2b63675 [client] propagate exit-node deselect to synthesized v6 (::/0) route (#6296)
* [client] propagate exit-node deselect to synthesized v6 (::/0) route

When a client deselects an IPv4 exit node, the auto-generated IPv6 default
route (::/0) was still selected and pushed onto the tunnel interface, even
though the user disabled the exit node. On an exit node without a real IPv6
egress this blackholes IPv6 traffic, and because clients prefer IPv6 (happy
eyeballs) it can break general connectivity.

Root cause: the synthesized v6 route gets a different NetID than its v4 base
(base + "-v6"). The route selector keys deselects by NetID and defaults
unknown NetIDs to selected, so the "-v6" entry was never matched by the v4
deselect. The effectiveNetID() mirror that solves exactly this is used by
HasUserSelectionForRoute and FilterSelectedExitNodes, but categorizeUserSelection
called the raw IsSelected(), bypassing it and mis-categorizing the v6 pair as
user-selected.

Add RouteSelector.IsSelectedForExitNode(), which applies effectiveNetID before
the selection check, and use it in categorizeUserSelection. IsSelected() is left
untouched so non-exit code paths don't make unrelated "*-v6" routes inherit v4
state. Adds regression tests for the v4/v6 deselect mirror and explicit-v6
override.

* [client] add DIAG logging to trace exit-node v6 (::/0) route filtering

Temporary diagnostics to find why a deselected v4 exit node's synthesized
::/0 route still reaches the tunnel. Logs the full install path: incoming
client networks, route-selector state before/after the management-driven
update, what updateExitNodeSelections deselects/selects, and per-route
KEEP/SKIP/DROP decisions in FilterSelectedExitNodes and applyExitNodeFilter.
To be reverted once the real root cause is confirmed from a client log.

* [client] clear orphaned v6 exit selection when v4 pair is toggled

Root cause of the leaking ::/0 route, confirmed from client logs: the
synthesized "-v6" exit route could stay explicitly selected in the persisted
route-selector state while its v4 base was deselected (selected=[...-v6],
deselected=[...v4base]). Because the v6 entry then has its own explicit state,
effectiveNetID stops mirroring the v4 base, so FilterSelectedExitNodes keeps
::/0 and it is installed on the tunnel even though the user disabled the exit
node. This happened because the iOS SDK's deselect only pairs the "-v6" sibling
via ExpandV6ExitPairs when the v6 route is present in the current routesMap; a
deselect at a moment it wasn't expanded left the v6 selection orphaned.

Fix at the selector write path so it is independent of routesMap timing: when a
v4 exit NetID is selected or deselected, clear any orphaned explicit state on
its "-v6" sibling (clearPairedV6Locked), unless the sibling is part of the same
batch (the deliberate ExpandV6ExitPairs case). The v6 then falls back to
inheriting the v4 base via effectiveNetID, so a v4 deselect also drops ::/0 and
a v4 select brings both back.

Adds regression tests: a stale explicit v6 selection is cleared by a later v4
deselect, and an explicit v6 select made in the same batch is preserved.

* [ios] compute route connection status in the bridge

The iOS bridge exposed a route's Network as a possibly comma-joined string
("0.0.0.0/0, ::/0" for a merged exit node) but no connection status, forcing
the UI to infer status by string-matching that joined value against peer
routes — which never matched for the merged exit node, leaving it stuck as
not-connected. Android already computes status in the core (findBestRoutePeer).

Mirror that here: add a Status field to RoutesSelectionInfo and compute it from
the connected peers' route tables, matching the route's primary prefix, a merged
exit node's extra v6 prefix, or a dynamic route's domain pattern (the key the
route manager records). The UI can now read the status directly.

* [client] remove exit-node v6 DIAG logging and tidy routeselector

Drop the temporary DIAG diagnostics added to trace the leaking ::/0 route
(the root cause is fixed and confirmed). Also reorganize routeselector.go so
the exit-node helpers (clearPairedV6Locked, isExitNode) sit next to the
exit-node code paths and MarshalJSON/UnmarshalJSON are grouped together.

* [client] mirror v4 exit selection onto v6 pair at write time

The synthesized "-v6" exit route shares its v4 base's NetID plus a "-v6"
suffix. Selection state was reconciled at read time via effectiveNetID, a
mirror that could only be applied on exit-node code paths, which forced a
parallel IsSelectedForExitNode() alongside IsSelected() and a clearPairedV6Locked()
orphan cleanup on every toggle. That machinery still missed the case observed
in the field: a persisted state with the v4 base deselected but its "-v6"
sibling explicitly selected (orphaned). Because effectiveNetID returns the v6
entry itself once it carries explicit state, and clearPairedV6Locked only fires
on a live toggle, the loaded orphan survived and the ::/0 route leaked onto the
tunnel despite the exit node being disabled, breaking IPv6 (happy eyeballs).

Treat the v4/v6 exit pair as a single toggle and keep state consistent at write
time instead. RouteSelector.SyncPairedSelection forces the "-v6" entry to match
its v4 base unconditionally, resetting any orphaned explicit state. The route
manager, which knows the route prefixes, computes the pairs (V6ExitMergeSet) and
calls it from updateRouteSelectorFromManagement before selection is read, so both
collectExitNodeInfo and FilterSelectedExitNodes see consistent state, including
pairs loaded from persisted selector state.

This removes effectiveNetID, IsSelectedForExitNode and clearPairedV6Locked; the
selector is literal again and no longer needs the "exit-node paths only" caveat.
HasUserSelectionForRoute and applyExitNodeFilter use the raw NetID.

Adds a selector test for SyncPairedSelection (including the orphaned-v6 case) and
a route-manager test reproducing the persisted-orphan scenario from the field log.

* [client] add DIAG logging to trace v6 exit-pair mirror

The write-time mirror did not eliminate the leak in field testing. Re-add the
DIAG diagnostics around the exit-node selection flow to capture a fresh trace:

- UpdateRoutes: incoming client networks, selector state before/after the
  management update, and the networks remaining after FilterSelectedExitNodes.
- mirrorV6ExitPairSelections: the NetIDs present in this update and the v6 pairs
  V6ExitMergeSet derives from them (reveals whether the v4 base and its ::/0 pair
  are present in the same update so the pair can be matched).
- SyncPairedSelection: the base/paired state before and after the sync.
- FilterSelectedExitNodes / applyExitNodeFilter: per-route SKIP/KEEP/DROP and the
  selection lookups behind each decision.
- updateExitNodeSelections / logExitNodeUpdate: categorization and deselect set.

Temporary; to be removed once the root cause is confirmed.

* [client] remove v6 exit-pair mirror DIAG logging

Drop the temporary DIAG diagnostics added to trace the v4/v6 exit-pair mirror.
The field log confirmed the write-time mirror keeps the pair consistent (the
::/0 route is only ever applied alongside its v4 base and is dropped on deselect),
so the diagnostics are no longer needed.
2026-06-16 12:27:58 +02:00
Maycon Santos
b3f9e6588a [management] sync openapi spec and test for diff on workflows (#6437)
* [management] sync openapi spec and test for diff on workflows

* [management] pin oapi-codegen version to v2.7.1
2026-06-15 17:53:25 +02:00
Pascal Fischer
967e2d6864 [management] network map for affected peers (#6105) 2026-06-15 17:43:22 +02:00
Zoltan Papp
e7c1d364c3 [management] treat ci- builds as development for remote jobs (#6436)
* fix(management): treat ci- builds as development for remote jobs

CI snapshot builds use a "ci-<sha>" version string that did not match
IsDevelopmentVersion, so the remote-jobs minimum-version gate rejected
them. Recognize the "ci-" prefix as a development build.

* fix(management): treat dev- builds as development for remote jobs

Dev snapshot builds use a "dev-<sha>" version string that did not match
IsDevelopmentVersion, so the remote-jobs minimum-version gate rejected
them. Recognize the "dev-" prefix as a development build, alongside the
existing "ci-" prefix.
2026-06-15 17:22:40 +02:00
Viktor Liu
a44198fd77 [client] Add dialWebSocket method to WASM client (#5980) 2026-06-15 16:43:24 +02:00
Viktor Liu
b57f714350 [client] Drop signaling-side ICE candidate filter, drop overlay STUN at mux read-side instead (#6142) 2026-06-15 16:37:03 +02:00
Viktor Liu
f893abc41d [client] Recover from tun device read/write panics and restart the client (#6419) 2026-06-15 16:36:00 +02:00
Lee Sang Hoon
60067619a1 [proxy] Keep custom TCP listeners alive after mapping batches (#6415) 2026-06-15 12:21:24 +02:00
Bethuel Mmbaga
cd777395f2 [management] Skip JWT group evaluation for embedded-IdP local users (#6422)
When JWT group sync is enabled with a restrictive JWTAllowGroups list, the local owner of an embedded-IdP (Dex) deployment can get locked out. The allow-groups check runs account-wide but local password users do not receive
external IdP group claims, so they can't satisfy the allowed list.

This skips JWT group evaluation for local Dex users so the restriction and JWT group sync continue to apply to external-IdP users as intended.
2026-06-15 12:01:54 +03:00
165 changed files with 11045 additions and 2667 deletions

View File

@@ -9,10 +9,13 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.5"
GORELEASER_VER: "v2.14.3"
SIGN_PIPE_VER: "v0.1.6"
GORELEASER_VER: "v2.16.0"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
flags: ""
SKIP_PUBLISH: "true"
SKIP_DOCKER_PUSH: "false"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
@@ -130,8 +133,6 @@ jobs:
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env:
flags: ""
steps:
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -143,8 +144,27 @@ jobs:
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set snapshot flag
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
if [[ "x-${{ github.repository }}" != "x-netbirdio/netbird" ]]; then
echo "SKIP_DOCKER_PUSH=true" >> $GITHUB_ENV
fi
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
with:
@@ -161,6 +181,8 @@ jobs:
${{ runner.os }}-go-releaser-
- name: Install modules
run: go mod tidy
- name: run openapi generator
run: bash shared/management/http/api/generate.sh
- name: check git status
run: git --no-pager diff --exit-code
- name: Set up QEMU
@@ -210,6 +232,8 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
SKIP_DOCKER_PUSH: ${{ env.SKIP_DOCKER_PUSH }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
@@ -332,8 +356,22 @@ jobs:
id: semver_parser
uses: netbirdio/shared-actions/actions/parse-semver@be5df6047383da2236e02243cceb857d8567c27e # v0.0.2
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set snapshot flag
if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: |
echo "flags=--snapshot" >> $GITHUB_ENV
- name: Set build vars
if: ${{ startsWith(github.ref, 'refs/tags/v') }}
run: |
if [[ "x-${{ steps.semver_parser.outputs.prerelease }}" == "x-" && "x-${{ github.repository }}" == "x-netbirdio/netbird" ]]; then
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
echo "SKIP_PUBLISH=false" >> $GITHUB_ENV
else
echo "x-${{ github.repository }}"
echo "x-${{ steps.semver_parser.outputs.prerelease }}"
fi
- name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
@@ -393,6 +431,7 @@ jobs:
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
SKIP_PUBLISH: ${{ env.SKIP_PUBLISH }}
- name: Verify RPM signatures
run: |
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '

View File

@@ -1,5 +1,7 @@
version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
- SKIP_DOCKER_PUSH={{ if index .Env "SKIP_DOCKER_PUSH" }}{{ .Env.SKIP_DOCKER_PUSH }}{{ else }}false{{ end }}
project_name: netbird
builds:
- id: netbird-wasm
@@ -74,6 +76,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -88,6 +92,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -102,6 +108,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -122,6 +130,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -136,6 +146,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -150,6 +162,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -170,6 +184,8 @@ builds:
- amd64
- arm64
- arm
goarm:
- 7
ldflags:
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
mod_timestamp: "{{ .CommitTimestamp }}"
@@ -222,670 +238,192 @@ nfpms:
rpm:
signature:
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
dockers:
- image_templates:
- netbirdio/netbird:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
ids:
- netbird
goarch: amd64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
ids:
- netbird
goarch: arm64
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
ids:
- netbird
goarch: arm
goarm: 6
use: buildx
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-amd64
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
ids:
- netbird-relay
goarch: amd64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
ids:
- netbird-relay
goarch: arm64
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
ids:
- netbird-relay
goarch: arm
goarm: 6
use: buildx
dockerfile: relay/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-amd64
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
ids:
- netbird-signal
goarch: amd64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
ids:
- netbird-signal
goarch: arm64
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
ids:
- netbird-signal
goarch: arm
goarm: 6
use: buildx
dockerfile: signal/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-amd64
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
ids:
- netbird-mgmt
goarch: amd64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
ids:
- netbird-mgmt
goarch: arm64
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
ids:
- netbird-mgmt
goarch: arm
goarm: 6
use: buildx
dockerfile: management/Dockerfile.debug
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-amd64
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
ids:
- netbird-upload
goarch: amd64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
ids:
- netbird-upload
goarch: arm64
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
ids:
- netbird-upload
goarch: arm
goarm: 6
use: buildx
dockerfile: upload-server/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-amd64
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
ids:
- netbird-server
goarch: amd64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
ids:
- netbird-server
goarch: arm64
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
ids:
- netbird-server
goarch: arm
goarm: 6
use: buildx
dockerfile: combined/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
ids:
- netbird-proxy
goarch: amd64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/amd64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
ids:
- netbird-proxy
goarch: arm64
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm64"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
- image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
ids:
- netbird-proxy
goarch: arm
goarm: 6
use: buildx
dockerfile: proxy/Dockerfile
build_flag_templates:
- "--platform=linux/arm"
- "--label=org.opencontainers.image.created={{.Date}}"
- "--label=org.opencontainers.image.title={{.ProjectName}}"
- "--label=org.opencontainers.image.version={{.Version}}"
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
- "--label=maintainer=dev@netbird.io"
docker_manifests:
- name_template: netbirdio/netbird:{{ .Version }}
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:latest
image_templates:
- netbirdio/netbird:{{ .Version }}-arm64v8
- netbirdio/netbird:{{ .Version }}-arm
- netbirdio/netbird:{{ .Version }}-amd64
- name_template: netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/netbird:rootless-latest
image_templates:
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- netbirdio/netbird:{{ .Version }}-rootless-arm
- netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: netbirdio/relay:{{ .Version }}
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/relay:latest
image_templates:
- netbirdio/relay:{{ .Version }}-arm64v8
- netbirdio/relay:{{ .Version }}-arm
- netbirdio/relay:{{ .Version }}-amd64
- name_template: netbirdio/signal:{{ .Version }}
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/signal:latest
image_templates:
- netbirdio/signal:{{ .Version }}-arm64v8
- netbirdio/signal:{{ .Version }}-arm
- netbirdio/signal:{{ .Version }}-amd64
- name_template: netbirdio/management:{{ .Version }}
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:latest
image_templates:
- netbirdio/management:{{ .Version }}-arm64v8
- netbirdio/management:{{ .Version }}-arm
- netbirdio/management:{{ .Version }}-amd64
- name_template: netbirdio/management:debug-latest
image_templates:
- netbirdio/management:{{ .Version }}-debug-arm64v8
- netbirdio/management:{{ .Version }}-debug-arm
- netbirdio/management:{{ .Version }}-debug-amd64
- name_template: netbirdio/upload:{{ .Version }}
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/upload:latest
image_templates:
- netbirdio/upload:{{ .Version }}-arm64v8
- netbirdio/upload:{{ .Version }}-arm
- netbirdio/upload:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:{{ .Version }}
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/netbird-server:latest
image_templates:
- netbirdio/netbird-server:{{ .Version }}-arm64v8
- netbirdio/netbird-server:{{ .Version }}-arm
- netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
image_templates:
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/relay:latest
image_templates:
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/signal:latest
image_templates:
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-arm
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/management:debug-latest
image_templates:
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/upload:latest
image_templates:
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/netbird-server:latest
image_templates:
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: netbirdio/reverse-proxy:latest
image_templates:
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- netbirdio/reverse-proxy:{{ .Version }}-arm
- netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
image_templates:
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
dockers_v2:
- id: netbird
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird
images:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
- client/netbird-entrypoint.sh
platforms:
- linux/amd64
- linux/arm64
- linux/arm/6
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-rootless
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird
images:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "v{{ .Version }}-rootless"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile-rootless
extra_files:
- client/netbird-entrypoint.sh
platforms:
- linux/amd64
- linux/arm64
- linux/arm/6
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: relay
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-relay
images:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: signal
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-signal
images:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: management
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-mgmt
images:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: upload
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-upload
images:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-server
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-server
images:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
- id: netbird-proxy
disable: "{{ .Env.SKIP_DOCKER_PUSH }}"
ids:
- netbird-proxy
images:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:
- linux/amd64
- linux/arm64
- linux/arm
annotations:
"org.opencontainers.image.created": "{{.Date}}"
"org.opencontainers.image.title": "{{.ProjectName}}"
"org.opencontainers.image.version": "{{.Version}}"
"org.opencontainers.image.revision": "{{.FullCommit}}"
"org.opencontainers.image.source": "{{.GitURL}}"
"maintainer": "dev@netbird.io"
brews:
- ids:
- default
skip_upload: "{{ .Env.SKIP_PUBLISH }}"
repository:
owner: netbirdio
name: homebrew-tap
@@ -902,6 +440,7 @@ brews:
uploads:
- name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_deb
mode: archive
@@ -910,6 +449,7 @@ uploads:
method: PUT
- name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_rpm
mode: archive

View File

@@ -1,5 +1,6 @@
version: 2
env:
- SKIP_PUBLISH={{ if index .Env "SKIP_PUBLISH" }}{{ .Env.SKIP_PUBLISH }}{{ else }}true{{ end }}
project_name: netbird-ui
builds:
- id: netbird-ui
@@ -101,6 +102,7 @@ nfpms:
uploads:
- name: debian
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_ui_deb
mode: archive
@@ -109,6 +111,7 @@ uploads:
method: PUT
- name: yum
skip: "{{ .Env.SKIP_PUBLISH }}"
ids:
- netbird_ui_rpm
mode: archive

View File

@@ -4,7 +4,7 @@
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.23.3
FROM alpine:3.24
# iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache \
bash \
@@ -21,7 +21,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -4,7 +4,7 @@
# podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
# podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
FROM alpine:3.22.0
FROM alpine:3.24
RUN apk add --no-cache \
bash \
@@ -27,7 +27,7 @@ ENV \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
ARG NETBIRD_BINARY=netbird
ARG TARGETPLATFORM
ARG NETBIRD_BINARY=$TARGETPLATFORM/netbird
COPY client/netbird-entrypoint.sh /usr/local/bin/netbird-entrypoint.sh
COPY "${NETBIRD_BINARY}" /usr/local/bin/netbird

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
@@ -24,6 +23,7 @@ const (
// Profile represents a profile for gomobile
type Profile struct {
ID string
Name string
IsActive bool
}
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
── personal.state.json ← Personal profile state
├── work.json ← Legacy work profile config
├── work.state.json ← Legacy work profile state
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config
── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state
*/
// ProfileManager manages profiles for Android
@@ -99,6 +99,7 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
ID: p.ID.String(),
Name: p.Name,
IsActive: p.IsActive,
})
@@ -108,55 +109,65 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
return nil, fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
// ActiveProfileState only stores the ID (and username), not the display
// name. Resolve the ID to the full profile so callers get the real Name.
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
}
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
func (pm *ProfileManager) SwitchProfile(id string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
ID: profilemanager.ID(id),
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
log.Infof("switched to profile: %s", id)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
if err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
log.Infof("created new profile: %s", profile.ID)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
func (pm *ProfileManager) LogoutProfile(id string) error {
configPath, err := pm.getProfileConfigPath(id)
if err != nil {
return err
}
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return fmt.Errorf("id '%s' is not valid", id)
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
return fmt.Errorf("profile '%s' does not exist", id)
}
// Read current config using internal profilemanager
@@ -174,53 +185,57 @@ func (pm *ProfileManager) LogoutProfile(profileName string) error {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
log.Infof("logged out from profile: %s", id)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
func (pm *ProfileManager) RemoveProfile(id string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
log.Infof("removed profile: %s", id)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
if id == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
return filepath.Join(profilesDir, id+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// GetConfigPath returns the config file path for a given profile id
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
return pm.getProfileConfigPath(id)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
return filepath.Join(profilesDir, id+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
@@ -230,7 +245,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
return pm.GetConfigPath(activeProfile.ID)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
@@ -240,18 +255,5 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
return pm.GetStateFilePath(activeProfile.ID)
}

View File

@@ -96,17 +96,19 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
handle := activeProf.ID.String()
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &activeProf.Name,
ProfileName: &handle,
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -170,14 +172,13 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -205,11 +206,15 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
return nil
}
func switchProfile(ctx context.Context, profileName string, username string) error {
// switchProfile asks the daemon to switch to the profile identified by
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
// ID so the caller can update the local active-profile state without
// re-resolving the handle.
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
@@ -217,15 +222,15 @@ func switchProfile(ctx context.Context, profileName string, username string) err
client := proto.NewDaemonServiceClient(conn)
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &username,
})
if err != nil {
return fmt.Errorf("switch profile failed: %v", err)
return "", fmt.Errorf("switch profile failed: %v", err)
}
return nil
return profilemanager.ID(resp.Id), nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
@@ -249,7 +254,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -277,7 +282,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
@@ -291,7 +296,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -306,10 +311,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
profileState, err := pm.GetProfileState(profileID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {

View File

@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
ID: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -2,11 +2,16 @@ package cmd
import (
"context"
"errors"
"fmt"
"os/user"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -14,6 +19,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var profileListShowID bool
var profileCmd = &cobra.Command{
Use: "profile",
Short: "Manage NetBird client profiles",
@@ -31,27 +38,40 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "Add a new profile",
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRenameCmd = &cobra.Command{
Use: "rename <profile> <new_profile_name>",
Short: "Renames an existing profile",
Long: `Renames an existing profile (by a name, ID, or unique ID prefix). Profile name is free-form.`,
Args: cobra.ExactArgs(2),
RunE: renameProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>",
Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
Use: "remove <profile>",
Short: "Remove a profile",
Long: `Remove a profile by name, ID, or unique ID prefix.`,
Aliases: []string{"rm"},
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>",
Use: "select <profile>",
Short: "Select a profile",
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func init() {
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
@@ -65,6 +85,7 @@ func setupCmd(cmd *cobra.Command) error {
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -83,25 +104,33 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
if profileListShowID {
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
} else {
fmt.Fprintln(tw, "NAME\tACTIVE")
}
return nil
for _, profile := range resp.Profiles {
marker := ""
if profile.IsActive {
marker = "✓"
}
name := profilemanager.StripCtrlChars(profile.Name)
id := profilemanager.ID(profile.Id)
if profileListShowID {
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
} else {
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
}
}
return tw.Flush()
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
@@ -121,21 +150,82 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return fmt.Errorf("add profile request: %w", err)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, profileName)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
}
func renameProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
cmd.Println("Profile added successfully:", profileName)
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
newProfilename := args[1]
resp, err := daemonClient.RenameProfile(cmd.Context(), &proto.RenameProfileRequest{
Handle: handle,
Username: currUser.Username,
NewProfileName: newProfilename,
})
if err != nil {
return wrapAmbiguityError(err, handle)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, newProfilename)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, newProfilename)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
cmd.Printf("Profile renamed from %s to %s\n", profilemanager.StripCtrlChars(resp.OldProfileName), profilemanager.StripCtrlChars(newProfilename))
return nil
}
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
if err != nil {
return 0, err
}
n := 0
for _, p := range resp.Profiles {
if p.Name == name {
n++
}
}
return n, nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -153,18 +243,17 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: handle,
Username: currUser.Username,
})
if err != nil {
return err
return wrapAmbiguityError(err, handle)
}
cmd.Println("Profile removed successfully:", profileName)
cmd.Printf("Profile removed: %s\n", resp.Id)
return nil
}
@@ -174,7 +263,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
profileManager := profilemanager.NewProfileManager()
profileName := args[0]
handle := args[0]
currUser, err := user.Current()
if err != nil {
@@ -191,32 +280,15 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
})
if err != nil {
return fmt.Errorf("list profiles: %w", err)
return wrapAmbiguityError(err, handle)
}
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil {
return err
}
@@ -231,6 +303,30 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
}
cmd.Println("Profile switched successfully to:", profileName)
id := profilemanager.ID(switchResp.Id)
cmd.Printf("Profile switched to: %s\n", id.ShortID())
return nil
}
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
// (which carry the resolver's message verbatim) into CLI-friendly text
// that points the user at --show-id.
func wrapAmbiguityError(err error, handle string) error {
if err == nil {
return nil
}
st, ok := gstatus.FromError(err)
if !ok {
return err
}
switch st.Code() {
case codes.InvalidArgument:
msg := st.Message()
if strings.Contains(msg, "ambiguous") {
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
}
case codes.NotFound:
return fmt.Errorf("profile %q not found", handle)
}
return err
}

View File

@@ -190,6 +190,7 @@ func init() {
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)
profileCmd.AddCommand(profileRenameCmd)
profileCmd.AddCommand(profileRemoveCmd)
profileCmd.AddCommand(profileSelectCmd)

View File

@@ -128,13 +128,12 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
err = switchProfile(cmd.Context(), profileName, username.Username)
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -190,7 +189,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -261,10 +260,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon")
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
@@ -289,10 +288,11 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
return fmt.Errorf("setup login request: %v", err)
}
loginRequest.ProfileName = &activeProf.Name
profileID := activeProf.ID.String()
loginRequest.ProfileName = &profileID
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -329,7 +329,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
}
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &activeProf.Name,
ProfileName: &profileID,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)

View File

@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
}
sm := profilemanager.ServiceManager{}
err = sm.AddProfile("test1", currUser.Username)
created, err := sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test1",
ID: created.ID,
Username: currUser.Username,
})
if err != nil {

View File

@@ -41,7 +41,6 @@ type ICEBind struct {
*wgConn.StdNetBind
transportNet transport.Net
filterFn udpmux.FilterFn
address wgaddr.Address
mtu uint16
@@ -61,12 +60,11 @@ type ICEBind struct {
ipv6Conn *net.UDPConn
}
func NewICEBind(transportNet transport.Net, filterFn udpmux.FilterFn, address wgaddr.Address, mtu uint16) *ICEBind {
func NewICEBind(transportNet transport.Net, address wgaddr.Address, mtu uint16) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{
StdNetBind: b,
transportNet: transportNet,
filterFn: filterFn,
address: address,
mtu: mtu,
endpoints: make(map[netip.Addr]net.Conn),
@@ -265,7 +263,6 @@ func (s *ICEBind) createOrUpdateMux() {
udpmux.UniversalUDPMuxParams{
UDPConn: muxConn,
Net: s.transportNet,
FilterFn: s.filterFn,
WGAddress: s.address,
MTU: s.mtu,
},

View File

@@ -289,7 +289,7 @@ func setupICEBind(t *testing.T) *ICEBind {
IP: netip.MustParseAddr("100.64.0.1"),
Network: netip.MustParsePrefix("100.64.0.0/10"),
}
return NewICEBind(transportNet, nil, address, 1280)
return NewICEBind(transportNet, address, 1280)
}
func createDualStackConns(t *testing.T) (*net.UDPConn, *net.UDPConn) {

View File

@@ -1,10 +1,13 @@
package device
import (
"fmt"
"net/netip"
"runtime/debug"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/tun"
)
@@ -41,10 +44,13 @@ type PacketCapture interface {
type FilteredDevice struct {
tun.Device
filter PacketFilter
capture atomic.Pointer[PacketCapture]
mutex sync.RWMutex
closeOnce sync.Once
filter PacketFilter
capture atomic.Pointer[PacketCapture]
// panicHandler is invoked after a panic in the underlying device is
// recovered in Read or Write.
panicHandler atomic.Pointer[func()]
mutex sync.RWMutex
closeOnce sync.Once
}
// newDeviceFilter constructor function
@@ -70,7 +76,7 @@ func (d *FilteredDevice) Close() error {
// Read wraps read method with filtering feature
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
if n, err = d.deviceRead(bufs, sizes, offset); err != nil {
return 0, err
}
@@ -112,7 +118,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RUnlock()
if filter == nil {
return d.Device.Write(bufs, offset)
return d.deviceWrite(bufs, offset)
}
filteredBufs := make([][]byte, 0, len(bufs))
@@ -125,9 +131,44 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
}
}
n, err := d.Device.Write(filteredBufs, offset)
n += dropped
return n, err
n, err := d.deviceWrite(filteredBufs, offset)
if err != nil {
return n, err
}
return n + dropped, nil
}
// deviceRead calls the underlying device Read, recovering from panics in the
// wintun read path and converting them into errors.
func (d *FilteredDevice) deviceRead(bufs [][]byte, sizes []int, offset int) (n int, err error) {
defer d.recoverFromPanic("read", &n, &err)
return d.Device.Read(bufs, sizes, offset)
}
// deviceWrite calls the underlying device Write, recovering from panics in the
// wintun write path and converting them into errors.
func (d *FilteredDevice) deviceWrite(bufs [][]byte, offset int) (n int, err error) {
defer d.recoverFromPanic("write", &n, &err)
return d.Device.Write(bufs, offset)
}
// recoverFromPanic converts a panic in the underlying device into a regular
// error and invokes the registered panic handler. The wintun read path is
// known to panic on zero-length packets that third-party filter drivers can
// place in the ring.
func (d *FilteredDevice) recoverFromPanic(op string, n *int, err *error) {
r := recover()
if r == nil {
return
}
log.Errorf("recovered panic in tun device %s: %v\n%s", op, r, debug.Stack())
*n = 0
*err = fmt.Errorf("tun device %s panic: %v", op, r)
if handler := d.panicHandler.Load(); handler != nil {
(*handler)()
}
}
// SetFilter sets packet filter to device
@@ -137,6 +178,17 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Unlock()
}
// SetPanicHandler registers a handler invoked after a recovered panic in Read
// or Write. The device is unusable after such a panic; the handler should
// trigger recreation of the interface. Pass nil to remove.
func (d *FilteredDevice) SetPanicHandler(handler func()) {
if handler == nil {
d.panicHandler.Store(nil)
return
}
d.panicHandler.Store(&handler)
}
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.

View File

@@ -221,3 +221,60 @@ func TestDeviceWrapperRead(t *testing.T) {
}
})
}
func TestDeviceWrapperReadPanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Read(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, sizes []int, offset int) (int, error) {
// Reproduce the wintun zero-length packet panic (index out of range).
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Read([][]byte{{}}, []int{0}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}
func TestDeviceWrapperWritePanic(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
tun := mocks.NewMockDevice(ctrl)
tun.EXPECT().Write(gomock.Any(), gomock.Any()).
DoAndReturn(func(bufs [][]byte, offset int) (int, error) {
packet := make([]byte, 0)
return int(packet[0]), nil
})
wrapped := newDeviceFilter(tun)
handlerCalled := false
wrapped.SetPanicHandler(func() { handlerCalled = true })
n, err := wrapped.Write([][]byte{{0x45, 0x00}}, 0)
if err == nil {
t.Errorf("expected error from recovered panic, got nil")
}
if n != 0 {
t.Errorf("expected n=0, got %d", n)
}
if !handlerCalled {
t.Errorf("expected panic handler to be called")
}
}

View File

@@ -32,8 +32,6 @@ type TunKernelDevice struct {
link *wgLink
udpMuxConn net.PacketConn
udpMux *udpmux.UniversalUDPMuxDefault
filterFn udpmux.FilterFn
}
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu uint16, transportNet transport.Net) *TunKernelDevice {
@@ -104,7 +102,6 @@ func (t *TunKernelDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
bindParams := udpmux.UniversalUDPMuxParams{
UDPConn: nbnet.WrapPacketConn(rawSock),
Net: t.transportNet,
FilterFn: t.filterFn,
WGAddress: t.address,
MTU: t.mtu,
}

View File

@@ -63,7 +63,6 @@ type WGIFaceOpts struct {
MTU uint16
MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net
FilterFn udpmux.FilterFn
DisableDNS bool
}

View File

@@ -11,7 +11,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
var tun WGTunDevice
if netstack.IsEnabled() {

View File

@@ -9,7 +9,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
if netstack.IsEnabled() {
wgIFace := &WGIface{

View File

@@ -10,7 +10,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunFd),

View File

@@ -14,7 +14,7 @@ import (
// NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewNetstackDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()),
userspaceBind: true,
@@ -30,7 +30,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
}
if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, opts.Address, opts.MTU)
iceBind := bind.NewICEBind(opts.TransportNet, opts.Address, opts.MTU)
return &WGIface{
tun: device.NewTunDevice(opts.IFaceName, opts.Address, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind),
userspaceBind: true,

View File

@@ -8,8 +8,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -22,10 +20,6 @@ import (
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct {
@@ -43,7 +37,6 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration
Net transport.Net
FilterFn FilterFn
WGAddress wgaddr.Address
MTU uint16
}
@@ -68,7 +61,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn,
mux: m,
logger: params.Logger,
filterFn: params.FilterFn,
address: params.WGAddress,
}
@@ -115,15 +107,12 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
}
}
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
// UDPConn is a wrapper around UDPMux conn that overrides WriteTo to drop packets destined for the overlay subnet.
type UDPConn struct {
net.PacketConn
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
address wgaddr.Address
mux *UniversalUDPMuxDefault
logger logging.LeveledLogger
address wgaddr.Address
}
// GetPacketConn returns the underlying PacketConn
@@ -132,65 +121,16 @@ func (u *UDPConn) GetPacketConn() net.PacketConn {
}
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
udpAddr, ok := addr.(*net.UDPAddr)
if !ok {
return u.PacketConn.WriteTo(b, addr)
}
if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}
return u.handleUncachedAddress(b, addr)
}
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}
a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}
if u.address.Network.Contains(a) {
dst := udpAddr.AddrPort().Addr().Unmap()
if (u.address.Network.IsValid() && u.address.Network.Contains(dst)) || (u.address.IPv6Net.IsValid() && u.address.IPv6Net.Contains(dst)) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return 0, fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}
func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
return u.PacketConn.WriteTo(b, addr)
}
// GetSharedConn returns the shared udp conn
@@ -225,6 +165,13 @@ func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.A
return nil
}
src := udpAddr.AddrPort().Addr().Unmap()
wg := m.params.WGAddress
if (wg.Network.IsValid() && wg.Network.Contains(src)) || (wg.IPv6Net.IsValid() && wg.IPv6Net.Contains(src)) {
log.Debugf("dropping STUN message from overlay source %s", udpAddr)
return nil
}
if m.isXORMappedResponse(msg, udpAddr.String()) {
err := m.handleXORMappedResponse(udpAddr, msg)
if err != nil {

View File

@@ -66,7 +66,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,

View File

@@ -22,7 +22,7 @@ func seedProxyForProxyCloseByRemoteConn() ([]proxyInstance, error) {
if err != nil {
return nil, err
}
iceBind := bind.NewICEBind(nil, nil, wgAddress, 1280)
iceBind := bind.NewICEBind(nil, wgAddress, 1280)
endpointAddress := &net.UDPAddr{
IP: net.IPv4(10, 0, 0, 1),
Port: 1234,

View File

@@ -118,6 +118,8 @@ func (c *ConnectClient) RunOniOS(
networkChangeListener listener.NetworkChangeListener,
dnsManager dns.IosDnsManager,
stateFilePath string,
cacheDir string,
logFilePath string,
) error {
// Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension.
debug.SetGCPercent(5)
@@ -127,8 +129,9 @@ func (c *ConnectClient) RunOniOS(
NetworkChangeListener: networkChangeListener,
DnsManager: dnsManager,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
return c.run(mobileDependency, nil, logFilePath)
}
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {

View File

@@ -250,6 +250,7 @@ type BundleGenerator struct {
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
statePath string
cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
@@ -276,6 +277,7 @@ type GeneratorDependencies struct {
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
StatePath string // Path to the state file. If empty, the ServiceManager default path is used.
CPUProfile []byte
CapturePath string
RefreshStatus func()
@@ -299,6 +301,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
statePath: deps.StatePath,
cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
@@ -850,8 +853,11 @@ func (g *BundleGenerator) maskSecrets() {
}
func (g *BundleGenerator) addStateFile() error {
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()
path := g.statePath
if path == "" {
sm := profilemanager.NewServiceManager("")
path = sm.GetStatePath()
}
if path == "" {
return nil
}

View File

@@ -0,0 +1,36 @@
//go:build ios
package debug
import (
"path/filepath"
log "github.com/sirupsen/logrus"
)
// swiftLogFile is the Swift app log written by the iOS app into the same log
// directory as the Go client log, so it can be collected into the bundle.
const swiftLogFile = "swift-log.log"
// addPlatformLog collects logs for the iOS debug bundle. iOS has no logcat or
// systemd journal, so we rely on file-based logs. addLogfile handles the Go
// client log (logPath) with rotation, the stderr/stdout companions and
// anonymization. The iOS app writes its own Swift log into the same directory,
// so we add it alongside the Go log.
func (g *BundleGenerator) addPlatformLog() error {
if err := g.addLogfile(); err != nil {
return err
}
if g.logPath == "" {
return nil
}
swiftLogPath := filepath.Join(filepath.Dir(g.logPath), swiftLogFile)
if err := g.addSingleLogfile(swiftLogPath, swiftLogFile); err != nil {
// The Swift log is best-effort: the app may not have written it yet.
log.Warnf("failed to add %s to debug bundle: %v", swiftLogFile, err)
}
return nil
}

View File

@@ -1,4 +1,4 @@
//go:build !android
//go:build !android && !ios
package debug

View File

@@ -843,6 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"Name": "non-config: profile name is not needed for debug purposes",
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
}

View File

@@ -53,7 +53,6 @@ import (
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/syncstore"
"github.com/netbirdio/netbird/client/internal/updater"
@@ -65,7 +64,6 @@ import (
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -240,7 +238,7 @@ type Engine struct {
syncStore syncstore.Store
syncStoreDir string
flowManager nftypes.FlowManager
flowManager nftypes.FlowManager
// auto-update
updateManager *updater.Manager
@@ -531,6 +529,10 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return fmt.Errorf("create wg interface: %w", err)
}
if filteredDevice := e.wgInterface.GetDevice(); filteredDevice != nil {
filteredDevice.SetPanicHandler(e.triggerClientRestart)
}
if err := e.createFirewall(); err != nil {
e.close()
return err
@@ -1075,11 +1077,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return ErrResetConnection
}
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
log.Infof("peer IPv6 address changed, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
if !e.config.DisableIPv6 {
reset, err := e.reconcileIPv6(conf)
if err != nil {
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
}
if reset {
log.Infof("peer IPv6 address changed value, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
}
if conf.GetSshConfig() != nil {
@@ -1101,25 +1109,58 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
// differs from the configured address (added, removed, or changed).
// Compares against e.config.WgAddr (not the interface address, which may have
// been cleared by ClearIPv6 if OS assignment failed).
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
current := e.config.WgAddr
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
// engine's WireGuard interface in place when possible. Three transitions:
//
// - First v6 assignment (current had no v6, conf carries one): apply via
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
// boot config has no v6 — without this we reset on every fresh start
// once management has v6 enabled, orphaning any netstack listeners
// held outside the engine.
// - v6 removed (current had v6, conf carries none): clear in place, no
// reset.
// - v6 swapped to a different non-empty value: returns reset=true so the
// caller falls back to the engine-recreate path — the underlying
// interface address can't be safely swapped in place across all
// backends (gVisor netstack in particular fixes its address at
// CreateNetTUN time).
//
// Mutates e.config.WgAddr to match the applied state so subsequent
// PeerConfig comparisons are stable.
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
raw := conf.GetAddressV6()
current := e.config.WgAddr
if len(raw) == 0 {
return current.HasIPv6()
if !current.HasIPv6() {
return false, nil
}
current.ClearIPv6()
e.config.WgAddr = current
if err := e.wgInterface.UpdateAddr(current); err != nil {
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
}
return false, nil
}
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
log.Errorf("decode v6 overlay address: %v", err)
return false
incoming := current
if err := incoming.SetIPv6FromCompact(raw); err != nil {
return false, fmt.Errorf("decode v6 overlay address: %w", err)
}
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
if !current.HasIPv6() {
e.config.WgAddr = incoming
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
}
return false, nil
}
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
return false, nil
}
return true, nil
}
func (e *Engine) receiveJobEvents() {
@@ -1711,6 +1752,13 @@ func (e *Engine) receiveSignalEvents() {
return e.ctx.Err()
}
// Self-addressed heartbeat: the signal client's receive watchdog
// round-trips this through the server to confirm the receive stream
// is delivering. Liveness is already recorded before this handler.
if msg.GetBody().GetType() == sProto.Body_HEARTBEAT {
return nil
}
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
@@ -1909,7 +1957,6 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
WGPrivKey: e.config.WgPrivateKey.String(),
MTU: e.config.MTU,
TransportNet: transportNet,
FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
}
@@ -2157,21 +2204,6 @@ func (e *Engine) startNetworkMonitor() {
}()
}
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.routeManager.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
func (e *Engine) stopDNSServer() {
if e.dnsServer == nil {
return

View File

@@ -0,0 +1,305 @@
package internal
import (
"context"
"errors"
"net/netip"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
// netstack clients: any first NetworkMap update that brings an IPv6 address
// used to trigger ErrResetConnection, which destroys the netstack and orphans
// every listener bound on it (proxy-side inbound listeners in particular).
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
// from "v6 swapped value" (must reset).
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err, "encode v6 prefix %s", p)
return b
}
// reconcileIPv6Fixture builds the smallest Engine the function under test
// needs: a config (with WgAddr being the load-bearing field) and a wgInterface
// whose UpdateAddr call we can observe.
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
t.Helper()
var applied wgaddr.Address
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return initial },
UpdateAddrFunc: func(a wgaddr.Address) error {
applied = a
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: initial},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
return e, mock, &applied
}
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
// Embedded clients boot v4-only; management later assigns a v6 overlay.
// The fix: apply v6 in place, return reset=false. Pre-fix this case
// fell through to the "v6 changed" branch and reset the engine.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, mock, applied := reconcileIPv6Fixture(t, v4)
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "first v6 assignment must NOT request an engine reset")
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")
require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
_ = mock
}
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
// Steady state: management redelivers the same PeerConfig. No interface
// mutation, no reset. Guards against an infinite reset loop if the
// comparison ever drifts (e.g. address-vs-prefix masking bugs).
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "unchanged v6 must NOT trigger reset")
assert.False(t, updateAddrCalled, "unchanged v6 must NOT call UpdateAddr")
}
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
// Management withdraws v6 (e.g. account toggled off the v6 group).
// Cleared in place, no reset.
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
e, _, applied := reconcileIPv6Fixture(t, addr)
e.config.WgAddr = addr
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: nil,
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "v6 removed must NOT trigger reset")
assert.False(t, e.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
}
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
// Same v6 host, different mask (e.g. /64 → /80). Treated like a value
// change because the new netmask redefines the broadcast/scope.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::1/80")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 prefix length change must request a reset")
assert.False(t, updateAddrCalled, "v6 prefix length change must NOT touch the interface")
}
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
// v6 was X, now Y. The netstack backend can't safely swap an existing
// address in place — fall back to the engine recreate path.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::2/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 value change must request a reset")
assert.False(t, updateAddrCalled,
"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
}
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
// Corrupt PeerConfig.AddressV6 must not crash the engine and must not
// trigger a spurious reset.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, _, applied := reconcileIPv6Fixture(t, v4)
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: []byte{0x00}, // truncated, definitely not a valid prefix
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "malformed v6 bytes must surface an error")
assert.False(t, reset, "decode error must NOT request a reset")
assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
}
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
// If WGIface.UpdateAddr fails (e.g. OS-side assignment error on a
// kernel device), reconcileIPv6 returns the error to the caller for
// logging — but it must NOT request a reset. The whole point of the
// fix is to AVOID the reset cascade on v6 transitions.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return errors.New("os refused address") },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "UpdateAddr failure must surface")
assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
}
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
// The integration check: updateConfig must not return ErrResetConnection
// when the only change between current state and the new PeerConfig is
// "v6 added". Pre-fix this returned ErrResetConnection, tearing down
// every listener bound on the engine's netstack.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return nil },
IsUserspaceBindFunc: func() bool { return true },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4, WgPort: 51820},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
statusRecorder: peer.NewRecorder("https://mgm.test"),
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
err := e.updateConfig(conf)
assert.NoError(t, err,
"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
assert.NotErrorIs(t, err, ErrResetConnection)
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6)
}

View File

@@ -66,7 +66,6 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -1707,82 +1706,12 @@ func getPeers(e *Engine) int {
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err)
return b
}
func TestEngine_hasIPv6Changed(t *testing.T) {
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
tests := []struct {
name string
current wgaddr.Address
confV6 []byte
expected bool
}{
{
name: "no v6 before, no v6 now",
current: v4Only,
confV6: nil,
expected: false,
},
{
name: "no v6 before, v6 added",
current: v4Only,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: true,
},
{
name: "had v6, now removed",
current: v4v6,
confV6: nil,
expected: true,
},
{
name: "had v6, same v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: false,
},
{
name: "had v6, different v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")),
expected: true,
},
{
name: "same v6 addr, different prefix length",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")),
expected: true,
},
{
name: "decode error keeps status quo",
current: v4Only,
confV6: []byte{1, 2, 3},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{WgAddr: tt.current},
}
conf := &mgmtProto.PeerConfig{
AddressV6: tt.confV6,
}
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
})
}
}
// The former TestEngine_hasIPv6Changed has been superseded by
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
// in place instead of demanding a full engine reset. The behavioral
// matrix the old test enforced is now covered, with corrected expectations,
// by TestReconcileIPv6_* in that sibling file.
func TestFilterAllowedIPs(t *testing.T) {
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")

View File

@@ -1024,14 +1024,17 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return d.relayStates
}
// extend the list of stun, turn servers with relay address
// extend the list of stun, turn servers with the relay server connections
relayStates := slices.Clone(d.relayStates)
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
states := d.relayMgr.RelayStates()
if len(states) == 0 {
// no relay connection tracked yet; surface configured servers as
// unavailable with the real reconnect error when known
err := relayClient.ErrRelayClientNotConnected
if connErr := d.relayMgr.RelayConnectError(); connErr != nil {
err = connErr
}
for _, r := range d.relayMgr.ServerURLs() {
relayStates = append(relayStates, relay.ProbeResult{
URI: r,
@@ -1041,10 +1044,14 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
return relayStates
}
relayState := relay.ProbeResult{
URI: instanceAddr,
for _, rs := range states {
relayStates = append(relayStates, relay.ProbeResult{
URI: rs.URL,
Err: rs.Err,
Transport: rs.Transport,
})
}
return append(relayStates, relayState)
return relayStates
}
func (d *Status) ForwardingRules() []firewall.ForwardRule {
@@ -1405,6 +1412,7 @@ func (fs FullStatus) ToProto() *proto.FullStatus {
pbRelayState := &proto.RelayState{
URI: relayState.URI,
Available: relayState.Err == nil,
Transport: relayState.Transport,
}
if err := relayState.Err; err != nil {
pbRelayState.Error = err.Error()

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"net"
"net/netip"
"strconv"
"sync"
"time"
@@ -165,10 +164,6 @@ func (w *WorkerICE) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HA
return
}
if candidateViaRoutes(candidate, haRoutes) {
return
}
if err := w.agent.AddRemoteCandidate(candidate); err != nil {
w.log.Errorf("error while handling remote candidate")
return
@@ -589,34 +584,6 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
return ec, nil
}
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
var routePrefixes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
routePrefixes = append(routePrefixes, routes[0].Network)
}
}
for _, prefix := range routePrefixes {
// default route is handled by route exclusion / ip rules
if prefix.Bits() == 0 {
continue
}
if prefix.Contains(addr) {
log.Debugf("Ignoring candidate [%s], its address is part of routed network %s", candidate.String(), prefix)
return true
}
}
return false
}
func isRelayCandidate(candidate ice.Candidate) bool {
return candidate.Type() == ice.CandidateTypeRelay
}

View File

@@ -108,6 +108,10 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Name is the human-readable profile name shown in CLI/UI listings.
// It is independent of the profile's on-disk filename (which is the ID).
Name string
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
@@ -270,6 +274,16 @@ func createNewConfig(input ConfigInput) (*Config, error) {
}
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.Name != "" {
sanitized, err := sanitizeDisplayName(config.Name)
if err != nil {
return false, fmt.Errorf("invalid profile name: %w", err)
}
if sanitized != config.Name {
config.Name = sanitized
updated = true
}
}
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)

View File

@@ -0,0 +1,118 @@
package profilemanager
import (
"crypto/rand"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
)
const (
// profileIDByteLen is the number of random bytes generated for a new
// profile ID. The resulting hex string is twice this length.
profileIDByteLen = 16
// shortIDLen is the number of leading characters of an ID we render in
// list output. Profiles per device are few, so 8 chars is collision-safe
// in practice and easy to type as a prefix.
shortIDLen = 8
// maxProfileNameLen caps the human-readable profile name to keep table
// output legible and prevent denial-of-service via huge JSON fields.
maxProfileNameLen = 128
// maxProfileIDLen bounds the on-disk filename we'll accept. New
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
// cap is generous enough to cover both without permitting absurdly
// long filenames.
maxProfileIDLen = 64
)
type ID string
// generateProfileID returns a new random hex ID for a profile file.
func generateProfileID() (ID, error) {
buf := make([]byte, profileIDByteLen)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("read random bytes: %w", err)
}
return ID(hex.EncodeToString(buf)), nil
}
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
// of a profile JSON filename.
func IsValidProfileFilenameStem(id ID) bool {
s := id.String()
if s == "" || len(s) > maxProfileIDLen {
return false
}
if s == defaultProfileName {
return true
}
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
return false
}
// filepath.Base catches any leftover separators on platforms with
// exotic path conventions.
if filepath.Base(s) != s {
return false
}
for _, r := range s {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
return false
}
}
return true
}
// sanitizeDisplayName normalizes a user-supplied profile display name for
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
// preserved. Returns an error if nothing usable remains.
func sanitizeDisplayName(name string) (string, error) {
if !utf8.ValidString(name) {
return "", fmt.Errorf("name is not valid UTF-8")
}
name = StripCtrlChars(name)
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("name is empty after sanitization")
}
if utf8.RuneCountInString(name) > maxProfileNameLen {
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
}
return name, nil
}
// StripCtrlChars control characters from a name before printing it.
func StripCtrlChars(name string) string {
var b strings.Builder
b.Grow(len(name))
for _, r := range name {
// Skip C0 controls and DEL, plus C1 controls (0x800x9F).
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
continue
}
b.WriteRune(r)
}
return b.String()
}
// ShortID truncates an ID for display.
func (id ID) ShortID() string {
if id == DefaultProfileName {
return DefaultProfileName
}
runes := []rune(id)
if len(runes) <= shortIDLen {
return id.String()
}
return string(runes[:shortIDLen])
}
func (id ID) String() string {
return string(id)
}

View File

@@ -19,19 +19,41 @@ const (
)
type Profile struct {
Name string
// ID is the on-disk filename stem (without .json). For new profiles
// it is a 32-char hex string; legacy profiles created before the
// ID-keyed layout keep their original name as their ID. The reserved
// value "default" identifies the special default profile.
ID ID
// Name is the human-readable display name. Falls back to ID when the
// underlying JSON has no "name" field set.
Name string
// Path is the absolute path to the profile JSON. Populated by the
// loader so callers do not have to reconstruct it from ID + dir.
Path string
IsActive bool
}
func (p *Profile) FilePath() (string, error) {
if p.Name == "" {
return "", fmt.Errorf("active profile name is empty")
if p.Path != "" {
return p.Path, nil
}
if p.Name == defaultProfileName {
id := p.ID
if id == "" {
id = ID(p.Name)
}
if id == "" {
return "", fmt.Errorf("profile ID is empty")
}
if id == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(id) {
return "", fmt.Errorf("invalid profile ID: %q", id)
}
username, err := user.Current()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
@@ -42,10 +64,13 @@ func (p *Profile) FilePath() (string, error) {
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
}
return filepath.Join(configDir, p.Name+".json"), nil
return filepath.Join(configDir, id.String()+".json"), nil
}
func (p *Profile) IsDefault() bool {
if p.ID != "" {
return p.ID == defaultProfileName
}
return p.Name == defaultProfileName
}
@@ -57,18 +82,24 @@ func NewProfileManager() *ProfileManager {
return &ProfileManager{}
}
// GetActiveProfile returns the active profile as recorded in the local
// user state file. Only ID is populated.
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
prof := pm.getActiveProfileState()
return &Profile{Name: prof}, nil
id := pm.getActiveProfileState()
return &Profile{ID: id}, nil
}
func (pm *ProfileManager) SwitchProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
// SwitchProfile records the given profile ID as active in the local user
// state file.
func (pm *ProfileManager) SwitchProfile(id ID) error {
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
if err := pm.setActiveProfileState(profileName); err != nil {
if err := pm.setActiveProfileState(id); err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
return nil
@@ -85,7 +116,7 @@ func sanitizeProfileName(name string) string {
}, name)
}
func (pm *ProfileManager) getActiveProfileState() string {
func (pm *ProfileManager) getActiveProfileState() ID {
configDir, err := getConfigDir()
if err != nil {
@@ -113,10 +144,10 @@ func (pm *ProfileManager) getActiveProfileState() string {
return defaultProfileName
}
return profileName
return ID(profileName)
}
func (pm *ProfileManager) setActiveProfileState(profileName string) error {
func (pm *ProfileManager) setActiveProfileState(id ID) error {
configDir, err := getConfigDir()
if err != nil {
@@ -125,7 +156,7 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
statePath := filepath.Join(configDir, activeProfileStateFilename)
err = os.WriteFile(statePath, []byte(profileName), 0600)
err = os.WriteFile(statePath, []byte(id), 0600)
if err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
@@ -142,7 +173,7 @@ func GetLoginHint() string {
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""

View File

@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
state, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet
err = sm.SetActiveProfileStateToDefault()
assert.NoError(t, err)
active, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, "default", active.Name)
assert.Equal(t, "default", active.ID.String())
})
})
}
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
currUser, err := user.Current()
assert.NoError(t, err)
sm := &ServiceManager{}
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
err = sm.SetActiveProfileState(state)
assert.NoError(t, err)
// Should error on nil or incomplete state
err = sm.SetActiveProfileState(nil)
assert.Error(t, err)
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
assert.Error(t, err)
})
})

View File

@@ -2,6 +2,7 @@ package profilemanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -23,12 +24,43 @@ var (
DefaultConfigPathDir = ""
DefaultConfigPath = ""
ActiveProfileStatePath = ""
)
var (
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
)
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
// matches more than one profile. Callers can render Candidates to help the
// user disambiguate.
type ErrAmbiguousHandle struct {
Handle string
Candidates []Profile
Kind AmbiguityKind
}
// AmbiguityKind describes which matcher produced the ambiguity, so callers
// can tailor the error message.
type AmbiguityKind int
const (
AmbiguityKindIDPrefix AmbiguityKind = iota
AmbiguityKindName
)
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
// reading all fields
type profileMeta struct {
Name string
}
func (e *ErrAmbiguousHandle) Error() string {
switch e.Kind {
case AmbiguityKindIDPrefix:
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
default:
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
}
}
func init() {
DefaultConfigPathDir = "/var/lib/netbird/"
@@ -54,25 +86,34 @@ func init() {
}
type ActiveProfileState struct {
Name string `json:"name"`
// ID is the on-disk filename stem of the active profile. The JSON tag stays
// as "name" for backwards compatibility with active state files written
// before the ID-based config files. Legacy values were profile names, which
// were also the legacy filename stems, so they still resolve to the correct
// file on disk.
ID ID `json:"name"`
Username string `json:"username"`
}
func (a *ActiveProfileState) FilePath() (string, error) {
if a.Name == "" {
return "", fmt.Errorf("active profile name is empty")
if a.ID == "" {
return "", fmt.Errorf("active profile ID is empty")
}
if a.Name == defaultProfileName {
if a.ID == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(a.ID) {
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
}
configDir, err := getConfigDirForUser(a.Username)
if err != nil {
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
}
return filepath.Join(configDir, a.Name+".json"), nil
return filepath.Join(configDir, a.ID.String()+".json"), nil
}
type ServiceManager struct {
@@ -178,7 +219,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
}, nil
} else {
@@ -186,12 +227,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
}
}
if activeProfile.Name == "" {
if activeProfile.ID == "" {
if err := s.SetActiveProfileStateToDefault(); err != nil {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
}, nil
}
@@ -216,25 +257,29 @@ func (s *ServiceManager) setDefaultActiveState() error {
}
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
if a == nil || a.Name == "" {
if a == nil || a.ID == "" {
return errors.New("invalid active profile state")
}
if a.Name != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
if a.ID != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
}
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
return fmt.Errorf("invalid profile ID: %q", a.ID)
}
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
log.Infof("active profile set to %s for %s", a.Name, a.Username)
log.Infof("active profile set to %s for %s", a.ID, a.Username)
return nil
}
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
return s.SetActiveProfileState(&ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
})
}
@@ -243,57 +288,117 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
// AddProfile creates a new profile with a generated ID. The user-supplied
// displayName is stored inside the JSON's name field, the on-disk filename
// uses the generated ID.
//
// The returned Profile carries the freshly-generated ID so callers can
// show it to the user (and so the gRPC AddProfileResponse can include
// it).
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
displayName, err = sanitizeDisplayName(displayName)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
return nil, fmt.Errorf("invalid profile name: %w", err)
}
id, err := generateProfileID()
if err != nil {
return nil, fmt.Errorf("generate profile id: %w", err)
}
profPath := filepath.Join(configDir, id.String()+".json")
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
if err != nil {
return fmt.Errorf("failed to create new config: %w", err)
return nil, fmt.Errorf("failed to create new config: %w", err)
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
return nil, fmt.Errorf("failed to write profile config: %w", err)
}
err = util.WriteJson(context.Background(), profPath, cfg)
return &Profile{
ID: id,
Name: displayName,
Path: profPath,
}, nil
}
func (s *ServiceManager) RenameProfile(id ID, username string, newName string) error {
displayName, err := sanitizeDisplayName(newName)
if err != nil {
return fmt.Errorf("failed to write profile config: %w", err)
return fmt.Errorf("invalid profile name: %w", err)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
data, err := os.ReadFile(target.Path)
if err != nil {
return err
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), target.Path, cfg); err != nil {
return fmt.Errorf("failed to write profile name: %w", err)
}
return nil
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
// RemoveProfile deletes the profile identified by id. Callers must have
// already resolved any user-supplied handle to a concrete ID via
// ResolveProfile.
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if id == defaultProfileName {
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
return fmt.Errorf("cannot remove default profile with name: %s", defaultName)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
return fmt.Errorf("load profiles: %w", err)
}
if !profileExists {
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
@@ -301,57 +406,26 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
return fmt.Errorf("failed to get active profile: %w", err)
}
if activeProf != nil && activeProf.Name == profileName {
return fmt.Errorf("cannot remove active profile: %s", profileName)
if activeProf != nil && activeProf.ID == id {
return fmt.Errorf("cannot remove active profile: %s", id)
}
err = util.RemoveJson(profPath)
if err != nil {
if err := util.RemoveJson(target.Path); err != nil {
return fmt.Errorf("failed to remove profile config: %w", err)
}
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
}
return nil
}
// ListProfiles returns every profile for the given user, including the
// default profile, with IsActive flags set.
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
files, err := util.ListFiles(configDir, "*.json")
if err != nil {
return nil, fmt.Errorf("failed to list profile files: %w", err)
}
var filtered []string
for _, file := range files {
if strings.HasSuffix(file, "state.json") {
continue // skip state files
}
filtered = append(filtered, file)
}
sort.Strings(filtered)
var activeProfName string
activeProf, err := s.GetActiveProfileState()
if err == nil {
activeProfName = activeProf.Name
}
var profiles []Profile
// add default profile always
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
for _, file := range filtered {
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
var isActive bool
if activeProfName != "" && activeProfName == profileName {
isActive = true
}
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
}
return profiles, nil
return s.loadAllProfiles(username)
}
// GetStatePath returns the path to the state file based on the operating system
@@ -369,7 +443,12 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
if activeProf.Name == defaultProfileName {
if activeProf.ID == defaultProfileName {
return defaultStatePath
}
if !IsValidProfileFilenameStem(activeProf.ID) {
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
return defaultStatePath
}
@@ -379,7 +458,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
return filepath.Join(configDir, activeProf.Name+".state.json")
return filepath.Join(configDir, activeProf.ID.String()+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
@@ -390,3 +469,169 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
return getConfigDirForUser(username)
}
// loadAllProfiles returns every profile visible to the daemon for the
// given user, including the default profile. The returned slice is sorted
// by ID for a stable display order.
//
// Each Profile is fully populated: ID is the filename stem, Name comes
// from the JSON's "name" field (falling back to the filename stem when absent)
// and Path is built from a basename read off disk.
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
activeID, activeIsDefault := s.activeProfileID()
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
profiles := []Profile{{
ID: defaultProfileName,
Name: defaultName,
Path: DefaultConfigPath,
IsActive: activeIsDefault,
}}
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
entries, err := os.ReadDir(configDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return profiles, nil
}
return nil, fmt.Errorf("read profile directory: %w", err)
}
var fileProfiles []Profile
for _, entry := range entries {
if entry.IsDir() {
continue
}
base := entry.Name()
if !strings.HasSuffix(base, ".json") {
continue
}
if strings.HasSuffix(base, ".state.json") {
continue
}
stem := ID(strings.TrimSuffix(base, ".json"))
if stem == defaultProfileName {
// default lives at the top-level config dir, not under /<user>
continue
}
if !IsValidProfileFilenameStem(ID(stem)) {
continue
}
path := filepath.Join(configDir, base)
name := readProfileName(path)
if name == "" {
name = stem.String()
}
fileProfiles = append(fileProfiles, Profile{
ID: stem,
Name: name,
Path: path,
IsActive: stem == ID(activeID),
})
}
sort.Slice(fileProfiles, func(i, j int) bool {
if fileProfiles[i].Name != fileProfiles[j].Name {
return fileProfiles[i].Name < fileProfiles[j].Name
}
// Sort tie-break on ID so duplicate names always render in the same order.
return fileProfiles[i].ID < fileProfiles[j].ID
})
profiles = append(profiles, fileProfiles...)
return profiles, nil
}
// readProfileName parses just the "name" field from the profile Json.
func readProfileName(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
var meta profileMeta
if err := json.Unmarshal(data, &meta); err != nil {
return ""
}
return meta.Name
}
// activeProfileID returns the currently-active profile's ID. The second
// return value is true when the active profile is the default one.
func (s *ServiceManager) activeProfileID() (ID, bool) {
state, err := s.GetActiveProfileState()
if err != nil || state == nil {
return defaultProfileName, true
}
if state.ID == "" || state.ID == defaultProfileName {
return defaultProfileName, true
}
return state.ID, false
}
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
// precedence is: exact ID match, then unique exact name, then unique ID
// prefix. Ambiguous matches return *ErrAmbiguousHandle so callers can
// surface the candidates.
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
if handle == "" {
return nil, fmt.Errorf("profile handle is empty")
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return nil, err
}
for i := range profiles {
if profiles[i].ID == ID(handle) {
return &profiles[i], nil
}
}
var nameMatches []Profile
for i := range profiles {
if profiles[i].Name == handle {
nameMatches = append(nameMatches, profiles[i])
}
}
if len(nameMatches) == 1 {
return &nameMatches[0], nil
}
if len(nameMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: nameMatches,
Kind: AmbiguityKindName,
}
}
// ID prefix match. Skip the default profile so `select d` does not
// accidentally pick it via prefix.
var prefixMatches []Profile
for i := range profiles {
if profiles[i].ID == defaultProfileName {
continue
}
if strings.HasPrefix(profiles[i].ID.String(), handle) {
prefixMatches = append(prefixMatches, profiles[i])
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: prefixMatches,
Kind: AmbiguityKindIDPrefix,
}
}
return nil, ErrProfileNotFound
}

View File

@@ -0,0 +1,230 @@
package profilemanager
import (
"context"
"errors"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
)
// withTestSM wires up patched globals + a clean config dir and returns a
// fully initialized ServiceManager plus the username we are scoped to.
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
t.Helper()
withTempConfigDir(t, func(configDir string) {
withPatchedGlobals(t, configDir, func() {
u, err := user.Current()
require.NoError(t, err)
sm := &ServiceManager{}
require.NoError(t, sm.CreateDefaultProfile())
fn(sm, u.Username)
})
})
}
func TestServiceProfile_ExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile(created.ID.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_IDPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
prefix := created.ID[:4]
got, err := sm.ResolveProfile(prefix.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
})
}
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
// Plant two profiles whose IDs share a known prefix by writing
// the files directly, since generated IDs are random.
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
path := filepath.Join(configDir, id+".json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
}
_, err = sm.ResolveProfile("abcd", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_ExactNameUnique(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile("work", username)
require.NoError(t, err)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_AmbiguousName(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.ResolveProfile("work", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindName, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_NotFound(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.ResolveProfile("nope", username)
assert.ErrorIs(t, err, ErrProfileNotFound)
})
}
func TestServiceProfile_DefaultByExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
got, err := sm.ResolveProfile(defaultProfileName, username)
require.NoError(t, err)
assert.Equal(t, defaultProfileName, got.ID.String())
})
}
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
// Legacy profiles stored as <name>.json with no "name" JSON field
// should still be discoverable by name and removable by name.
withTestSM(t, func(sm *ServiceManager, username string) {
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
path := filepath.Join(configDir, "legacy.json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
got, err := sm.ResolveProfile("legacy", username)
require.NoError(t, err)
assert.Equal(t, "legacy", got.ID.String())
// Name falls back to the filename stem when JSON omits it.
assert.Equal(t, "legacy", got.Name)
})
}
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
first, err := sm.AddProfile("work", username)
require.NoError(t, err)
second, err := sm.AddProfile("work", username)
require.NoError(t, err)
assert.NotEqual(t, first.ID, second.ID)
assert.Equal(t, "work", second.Name)
})
}
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
cases := []string{
"", // empty
"\x00\x01", // only control chars (becomes empty)
strings.Repeat("a", maxProfileNameLen+1), // too long
}
for _, name := range cases {
_, err := sm.AddProfile(name, username)
assert.Error(t, err, "expected error for %q", name)
}
})
}
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
err := sm.RemoveProfile("../escape", username)
assert.Error(t, err)
})
}
func TestSanitizeDisplayName(t *testing.T) {
cases := []struct {
in string
want string
wantErr bool
}{
{"work", "work", false},
{"My Work Account", "My Work Account", false},
{"emoji 🚀 ok", "emoji 🚀 ok", false},
{"漢字テスト", "漢字テスト", false},
{"with\x00null", "withnull", false},
{"\x01\x02\x03", "", true},
{"", "", true},
}
for _, tc := range cases {
got, err := sanitizeDisplayName(tc.in)
if tc.wantErr {
assert.Error(t, err, "case %q", tc.in)
continue
}
assert.NoError(t, err, "case %q", tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestIsValidProfileFilenameStem(t *testing.T) {
cases := []struct {
in string
want bool
}{
{"default", true},
{"abc123def456", true},
{"legacy-name", true},
{"legacy_name", true},
{"", false},
{"..", false},
{"../etc", false},
{"foo/bar", false},
{`foo\bar`, false},
{"with space", false},
{"with.dot", false},
{strings.Repeat("a", maxProfileIDLen+1), false},
}
for _, tc := range cases {
got := IsValidProfileFilenameStem(ID(tc.in))
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
require.NoError(t, sm.RemoveProfile(created.ID, username))
_, err = os.Stat(statePath)
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
})
}

View File

@@ -13,13 +13,20 @@ type ProfileState struct {
Email string `json:"email"`
}
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
// GetProfileState reads the per-profile state file keyed by profile ID.
// The state file lives in the user's config directory. Legacy state files
// keyed by the old profile name remain readable.
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
configDir, err := getConfigDir()
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
stateFile := filepath.Join(configDir, profileName+".state.json")
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return nil, fmt.Errorf("invalid profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
@@ -51,7 +58,12 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
return fmt.Errorf("get active profile: %w", err)
}
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
id := activeProf.ID
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid active profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
if err != nil {
return fmt.Errorf("write profile state: %w", err)

View File

@@ -32,6 +32,9 @@ type ProbeResult struct {
URI string
Err error
Addr string
// Transport is the negotiated relay transport, empty
// for stun/turn probes or when not connected.
Transport string
}
type StunTurnProbe struct {

View File

@@ -22,14 +22,14 @@ type removePeerCall struct {
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {

View File

@@ -41,4 +41,3 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -9,6 +9,7 @@ import (
"net/url"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@@ -332,6 +333,8 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
}
}
m.notifier.Close()
m.mux.Lock()
defer m.mux.Unlock()
m.clientRoutes = nil
@@ -700,6 +703,8 @@ func resolveURLsToIPs(urls []string) []net.IP {
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
// An explicit user "deselect all" must not be overridden by management auto-apply.
// Auto-applying an exit node here would call SelectRoutes, which clears the
// deselect-all flag and re-enables every route the user turned off.
@@ -716,6 +721,24 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
m.logExitNodeUpdate(exitNodeInfo)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
// consistent with its v4 base. The v4/v6 exit pair is a single toggle, so the v6
// entry always follows the base: deselecting the v4 exit node also drops its ::/0
// pair, and any stale (orphaned) explicit selection on the v6 entry is reset. This
// runs before selection is read so both collectExitNodeInfo and FilterSelectedExitNodes
// see consistent state, including pairs loaded from persisted selector state.
func (m *DefaultManager) mirrorV6ExitPairSelections(clientRoutes route.HAMap) {
routesByNetID := make(map[route.NetID][]*route.Route, len(clientRoutes))
for haID, routes := range clientRoutes {
routesByNetID[haID.NetID()] = routes
}
for v6ID := range route.V6ExitMergeSet(routesByNetID) {
baseID := route.NetID(strings.TrimSuffix(string(v6ID), route.V6ExitSuffix))
m.routeSelector.SyncPairedSelection(baseID, v6ID)
}
}
type exitNodeInfo struct {
allIDs []route.NetID
selectedByManagement []route.NetID

View File

@@ -0,0 +1,47 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
// TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair reproduces the bug seen
// in netbird-engine.log: persisted selector state has the v4 exit node deselected
// but its synthesized "-v6" pair explicitly selected (orphaned), so the ::/0 route
// leaked onto the tunnel. The management update must mirror the v4 deselect onto the
// v6 pair so FilterSelectedExitNodes drops it.
func TestUpdateRouteSelectorFromManagement_MirrorsV6ExitPair(t *testing.T) {
const (
v4ID = route.NetID("Exit Node (raspberrypi)")
v6ID = route.NetID("Exit Node (raspberrypi)-v6")
)
all := []route.NetID{v4ID, v6ID}
rs := routeselector.NewRouteSelector()
// Orphan the v6 selection: select the pair, then deselect only the v4 base.
require.NoError(t, rs.SelectRoutes([]route.NetID{v4ID, v6ID}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{v4ID}, all))
require.True(t, rs.IsSelected(v6ID), "precondition: orphaned v6 selection survives v4 deselect")
m := &DefaultManager{routeSelector: rs}
v4Route := &route.Route{NetID: v4ID, Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: v6ID, Network: netip.MustParsePrefix("::/0")}
clientRoutes := route.HAMap{
"Exit Node (raspberrypi)|0.0.0.0/0": {v4Route},
"Exit Node (raspberrypi)-v6|::/0": {v6Route},
}
m.updateRouteSelectorFromManagement(clientRoutes)
assert.False(t, rs.IsSelected(v6ID), "v6 pair must follow the v4 base deselect after the management update")
filtered := rs.FilterSelectedExitNodes(clientRoutes)
assert.Empty(t, filtered, "deselected v4 exit node must not leak its ::/0 pair onto the tunnel")
}

View File

@@ -16,7 +16,7 @@ import (
type Notifier struct {
initialRoutes []*route.Route
currentRoutes []*route.Route
fakeIPRoutes []*route.Route
fakeIPRoutes []*route.Route
listener listener.NetworkChangeListener
listenerMux sync.Mutex
@@ -119,3 +119,7 @@ func (n *Notifier) GetInitialRouteRanges() []string {
sort.Strings(initialStrings)
return initialStrings
}
func (n *Notifier) Close() {
// unused
}

View File

@@ -3,6 +3,7 @@
package notifier
import (
"container/list"
"net/netip"
"slices"
"sort"
@@ -14,19 +15,26 @@ import (
)
type Notifier struct {
mu sync.Mutex
cond *sync.Cond
currentPrefixes []string
listener listener.NetworkChangeListener
listenerMux sync.Mutex
listener listener.NetworkChangeListener
queue *list.List
closed bool
}
func NewNotifier() *Notifier {
return &Notifier{}
n := &Notifier{
queue: list.New(),
}
n.cond = sync.NewCond(&n.mu)
go n.deliverLoop()
return n
}
func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
n.mu.Lock()
defer n.mu.Unlock()
n.listener = listener
}
@@ -43,32 +51,52 @@ func (n *Notifier) OnNewRoutes(route.HAMap) {
}
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0)
newNets := make([]string, 0, len(prefixes))
for _, prefix := range prefixes {
newNets = append(newNets, prefix.String())
}
sort.Strings(newNets)
n.mu.Lock()
if slices.Equal(n.currentPrefixes, newNets) {
n.mu.Unlock()
return
}
n.currentPrefixes = newNets
n.notify()
routes := strings.Join(n.currentPrefixes, ",")
n.queue.PushBack(routes)
n.cond.Signal()
n.mu.Unlock()
}
func (n *Notifier) notify() {
n.listenerMux.Lock()
defer n.listenerMux.Unlock()
if n.listener == nil {
return
}
go func(l listener.NetworkChangeListener) {
l.OnNetworkChanged(strings.Join(n.currentPrefixes, ","))
}(n.listener)
func (n *Notifier) Close() {
n.mu.Lock()
n.closed = true
n.cond.Signal()
n.mu.Unlock()
}
func (n *Notifier) GetInitialRouteRanges() []string {
return nil
}
func (n *Notifier) deliverLoop() {
for {
n.mu.Lock()
for n.queue.Len() == 0 && !n.closed {
n.cond.Wait()
}
if n.closed && n.queue.Len() == 0 {
n.mu.Unlock()
return
}
routes := n.queue.Remove(n.queue.Front()).(string)
l := n.listener
n.mu.Unlock()
if l != nil {
l.OnNetworkChanged(routes)
}
}
}

View File

@@ -38,3 +38,7 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
func (n *Notifier) GetInitialRouteRanges() []string {
return []string{}
}
func (n *Notifier) Close() {
// unused
}

View File

@@ -121,9 +121,12 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
return Nexthop{}, vars.ErrRouteNotAllowed
}
// Check if the prefix is part of any local subnets
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
// BSDs blackhole a /32 added inside a directly-connected subnet; Linux/Windows need it to beat the wt0 route.
switch runtime.GOOS {
case "darwin", "freebsd", "netbsd", "openbsd", "dragonfly":
if isLocal, subnet := r.isPrefixInLocalSubnets(prefix); isLocal {
return Nexthop{}, fmt.Errorf("prefix %s is part of local subnet %s: %w", prefix, subnet, vars.ErrRouteNotAllowed)
}
}
// Determine the exit interface and next hop for the prefix, so we can add a specific route

View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"github.com/hashicorp/go-multierror"
@@ -132,6 +131,33 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
return rs.isSelectedLocked(routeID)
}
// SyncPairedSelection forces pairedID's explicit selection state to match baseID's,
// so a synthesized "-v6" exit route always follows its v4 base: selecting or
// deselecting the v4 exit node governs the ::/0 pair, and any stale (orphaned)
// explicit state on the v6 entry is reset. The v4/v6 exit pair is treated as a single
// toggle, so the v6 entry carries no independent selection of its own.
func (rs *RouteSelector) SyncPairedSelection(baseID, pairedID route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
_, baseSelected := rs.selectedRoutes[baseID]
_, baseDeselected := rs.deselectedRoutes[baseID]
delete(rs.selectedRoutes, pairedID)
delete(rs.deselectedRoutes, pairedID)
switch {
case baseSelected:
rs.selectedRoutes[pairedID] = struct{}{}
case baseDeselected:
rs.deselectedRoutes[pairedID] = struct{}{}
}
}
// FilterSelected removes unselected routes from the provided map.
func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
rs.mu.RLock()
@@ -151,14 +177,13 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap {
}
// HasUserSelectionForRoute returns true if the user has explicitly selected or deselected this route.
// Intended for exit-node code paths: a v6 exit-node pair (e.g. "MyExit-v6") with no explicit state of
// its own inherits its v4 base's state, so legacy persisted selections that predate v6 pairing
// transparently apply to the synthesized v6 entry.
// The lookup is literal; v4/v6 exit pairs are kept consistent at write time via SyncPairedSelection,
// so a synthesized "-v6" entry carries the same explicit state as its v4 base.
func (rs *RouteSelector) HasUserSelectionForRoute(routeID route.NetID) bool {
rs.mu.RLock()
defer rs.mu.RUnlock()
return rs.hasUserSelectionForRouteLocked(rs.effectiveNetID(routeID))
return rs.hasUserSelectionForRouteLocked(routeID)
}
func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap {
@@ -187,83 +212,6 @@ func (rs *RouteSelector) FilterSelectedExitNodes(routes route.HAMap) route.HAMap
return filtered
}
// effectiveNetID returns the v4 base for a "-v6" exit pair entry that has no explicit
// state of its own, so selections made on the v4 entry govern the v6 entry automatically.
// Only call this from exit-node-specific code paths: applying it to a non-exit "-v6" route
// would make it inherit unrelated v4 state. Must be called with rs.mu held.
func (rs *RouteSelector) effectiveNetID(id route.NetID) route.NetID {
name := string(id)
if !strings.HasSuffix(name, route.V6ExitSuffix) {
return id
}
if _, ok := rs.selectedRoutes[id]; ok {
return id
}
if _, ok := rs.deselectedRoutes[id]; ok {
return id
}
return route.NetID(strings.TrimSuffix(name, route.V6ExitSuffix))
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
// Exit-node path: apply the v4/v6 pair mirror so a deselect on the v4 base also
// drops the synthesized v6 entry that lacks its own explicit state.
effective := rs.effectiveNetID(netID)
if rs.hasUserSelectionForRouteLocked(effective) {
if rs.isSelectedLocked(effective) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}
// MarshalJSON implements the json.Marshaler interface
func (rs *RouteSelector) MarshalJSON() ([]byte, error) {
rs.mu.RLock()
@@ -317,3 +265,59 @@ func (rs *RouteSelector) UnmarshalJSON(data []byte) error {
return nil
}
func (rs *RouteSelector) isSelectedLocked(routeID route.NetID) bool {
if rs.deselectAll {
return false
}
_, deselected := rs.deselectedRoutes[routeID]
return !deselected
}
func (rs *RouteSelector) isDeselectedLocked(netID route.NetID) bool {
if rs.deselectAll {
return true
}
_, deselected := rs.deselectedRoutes[netID]
return deselected
}
func (rs *RouteSelector) hasUserSelectionForRouteLocked(routeID route.NetID) bool {
_, selected := rs.selectedRoutes[routeID]
_, deselected := rs.deselectedRoutes[routeID]
return selected || deselected
}
func (rs *RouteSelector) applyExitNodeFilter(
id route.HAUniqueID,
netID route.NetID,
rt []*route.Route,
out route.HAMap,
) {
if rs.hasUserSelectionForRouteLocked(netID) {
if rs.isSelectedLocked(netID) {
out[id] = rt
}
return
}
// no explicit selection for this route: defer to management's SkipAutoApply flag
sel := collectSelected(rt)
if len(sel) > 0 {
out[id] = sel
}
}
func isExitNode(rt []*route.Route) bool {
return len(rt) > 0 && (route.IsV4DefaultRoute(rt[0].Network) || route.IsV6DefaultRoute(rt[0].Network))
}
func collectSelected(rt []*route.Route) []*route.Route {
var sel []*route.Route
for _, r := range rt {
if !r.SkipAutoApply {
sel = append(sel, r)
}
}
return sel
}

View File

@@ -330,39 +330,73 @@ func TestRouteSelector_FilterSelectedExitNodes(t *testing.T) {
assert.Len(t, filtered, 0) // No routes should be selected
}
// TestRouteSelector_V6ExitPairInherits covers the v4/v6 exit-node pair selection
// mirror. The mirror is scoped to exit-node code paths: HasUserSelectionForRoute
// and FilterSelectedExitNodes resolve a "-v6" entry without explicit state to its
// v4 base, so legacy persisted selections that predate v6 pairing transparently
// apply to the synthesized v6 entry. General lookups (IsSelected, FilterSelected)
// stay literal so unrelated routes named "*-v6" don't inherit unrelated state.
func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
// TestRouteSelector_V6ExitPairSync covers SyncPairedSelection, which keeps a v4
// exit node and its synthesized "-v6" counterpart consistent. The selector itself
// is literal and never infers a v6 entry's state from its v4 base; callers that know
// the pairing (exit-node code paths) call SyncPairedSelection to force the v6 entry
// to follow the base, treating the pair as a single toggle.
func TestRouteSelector_V6ExitPairSync(t *testing.T) {
all := []route.NetID{"exit1", "exit1-v6", "exit2", "exit2-v6", "corp", "corp-v6"}
t.Run("HasUserSelectionForRoute mirrors deselected v4 base", func(t *testing.T) {
t.Run("selector lookups stay literal without sync", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 pair sees v4 base's user selection")
// The selector does not pair-resolve: the v6 entry is independent until synced.
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 entry has no state of its own")
assert.True(t, rs.IsSelected("exit1-v6"), "unsynced v6 entry stays selected by default")
// unrelated v6 with no v4 base touched is unaffected
assert.False(t, rs.HasUserSelectionForRoute("exit2-v6"))
// A route literally named "exit1-something" must never pair-resolve either.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("IsSelected stays literal for non-exit lookups", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))
// A non-exit route literally named "corp-v6" must not inherit "corp"'s state
// via the mirror; the mirror only applies in exit-node code paths.
assert.False(t, rs.IsSelected("corp"))
assert.True(t, rs.IsSelected("corp-v6"), "non-exit *-v6 routes must not inherit unrelated v4 state")
})
t.Run("explicit v6 state overrides v4 base in filter", func(t *testing.T) {
t.Run("sync mirrors deselected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1"))
assert.False(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base deselect")
assert.True(t, rs.HasUserSelectionForRoute("exit1-v6"), "v6 carries explicit deselect after sync")
})
t.Run("sync mirrors selected v4 base onto v6", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1"}, false, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.True(t, rs.IsSelected("exit1"))
assert.True(t, rs.IsSelected("exit1-v6"), "v6 pair follows v4 base select")
})
t.Run("sync clears v6 state when base has no explicit selection", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1-v6"}, true, all))
require.True(t, rs.HasUserSelectionForRoute("exit1-v6"))
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"),
"v6 explicit state is cleared so it follows management like its base")
})
// Regression for the observed bug (see netbird-engine.log): persisted state has
// the v4 base deselected but the v6 sibling explicitly selected (orphaned). The
// sync must reset the orphan so the ::/0 route does not leak onto the tunnel.
t.Run("sync clears orphaned explicit v6 selection on deselected base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
// Prior state: both explicitly selected, then only the v4 base deselected,
// leaving the v6 entry as a stale explicit selection.
require.NoError(t, rs.SelectRoutes([]route.NetID{"exit1", "exit1-v6"}, true, all))
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
require.True(t, rs.IsSelected("exit1-v6"), "precondition: orphaned v6 selection")
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.IsSelected("exit1-v6"), "orphaned v6 selection reset to follow v4 deselect")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -370,23 +404,14 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
"exit1|0.0.0.0/0": {v4Route},
"exit1-v6|::/0": {v6Route},
}
filtered := rs.FilterSelectedExitNodes(routes)
assert.NotContains(t, filtered, route.HAUniqueID("exit1|0.0.0.0/0"))
assert.Contains(t, filtered, route.HAUniqueID("exit1-v6|::/0"), "explicit v6 select wins over v4 base")
assert.Empty(t, filtered, "deselecting v4 base must drop the v6 pair even if it was explicitly selected before")
})
t.Run("non-v6-suffix routes unaffected", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
// A route literally named "exit1-something" must not pair-resolve.
assert.False(t, rs.HasUserSelectionForRoute("exit1-something"))
})
t.Run("filter v6 paired with deselected v4 base", func(t *testing.T) {
t.Run("filter drops synced v6 pair of deselected v4 base", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"exit1"}, all))
rs.SyncPairedSelection("exit1", "exit1-v6")
v4Route := &route.Route{NetID: "exit1", Network: netip.MustParsePrefix("0.0.0.0/0")}
v6Route := &route.Route{NetID: "exit1-v6", Network: netip.MustParsePrefix("::/0")}
@@ -399,6 +424,15 @@ func TestRouteSelector_V6ExitPairInherits(t *testing.T) {
assert.Empty(t, filtered, "deselecting v4 base must also drop the v6 pair")
})
t.Run("deselectAll makes sync a no-op", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
rs.DeselectAllRoutes()
rs.SyncPairedSelection("exit1", "exit1-v6")
assert.False(t, rs.HasUserSelectionForRoute("exit1-v6"), "sync must not write explicit state under deselectAll")
})
t.Run("non-exit *-v6 routes pass through FilterSelectedExitNodes", func(t *testing.T) {
rs := routeselector.NewRouteSelector()
require.NoError(t, rs.DeselectRoutes([]route.NetID{"corp"}, all))

View File

@@ -17,6 +17,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -25,6 +26,7 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -54,6 +56,7 @@ type selectRoute struct {
Network netip.Prefix
Domains domain.List
Selected bool
Status string
extraNetworks []netip.Prefix
}
@@ -65,6 +68,8 @@ func init() {
type Client struct {
cfgFile string
stateFile string
cacheDir string
logFilePath string
recorder *peer.Status
ctxCancel context.CancelFunc
ctxCancelLock *sync.Mutex
@@ -75,16 +80,21 @@ type Client struct {
onHostDnsFn func([]string)
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
}
// NewClient instantiate a new Client
func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
func NewClient(cfgFile, stateFile, cacheDir, logFilePath, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client {
return &Client{
cfgFile: cfgFile,
stateFile: stateFile,
cacheDir: cacheDir,
logFilePath: logFilePath,
deviceName: deviceName,
osName: osName,
osVersion: osVersion,
@@ -161,8 +171,13 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, connectClient)
// Persist the latest sync response so DebugBundle can include the network
// map. On iOS this is backed by disk to keep it out of the constrained
// process memory (see the syncstore package).
connectClient.SetSyncResponsePersistence(true)
return connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile, c.cacheDir, c.logFilePath)
}
// Stop the internal client and free the resources
@@ -174,6 +189,84 @@ func (c *Client) Stop() {
}
c.ctxCancel()
c.setState(nil, nil)
}
// DebugBundle generates a debug bundle, uploads it and returns the upload key.
// It works with or without a running engine: when the engine is up it reuses
// the live config, sync response and client metrics; otherwise it loads the
// config from disk (or the preloaded tvOS config).
func (c *Client) DebugBundle(anonymize bool) (string, error) {
cfg, cc := c.stateSnapshot()
// If the engine hasn't been started, load config so we can reach management.
if cfg == nil {
if c.preloadedConfig != nil {
cfg = c.preloadedConfig
} else {
var err error
// Use DirectUpdateOrCreateConfig to avoid atomic file operations
// (temp file + rename) blocked by the tvOS sandbox.
cfg, err = profilemanager.DirectUpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: c.cfgFile,
StateFilePath: c.stateFile,
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
}
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: c.cacheDir,
StatePath: c.stateFile,
LogPath: c.logFilePath,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
@@ -227,6 +320,16 @@ func (c *Client) RemoveConnectionListener() {
c.recorder.RemoveConnectionListener()
}
// IsLoginRequiredCached reports whether the LAST observed management error was an
// auth failure (PermissionDenied/InvalidArgument), using the in-memory status
// recorder. Unlike IsLoginRequired() it performs NO network call, so it is safe to
// call from the connection listener during teardown (e.g. onDisconnected) without
// blocking on a slow or unavailable network. Returns false while connected to
// management or when the last error was not auth-related.
func (c *Client) IsLoginRequiredCached() bool {
return c.recorder.IsLoginRequired()
}
func (c *Client) IsLoginRequired() bool {
var ctx context.Context
//nolint
@@ -354,11 +457,12 @@ func (c *Client) ClearLoginComplete() {
}
func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
if c.connectClient == nil {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
return nil, fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
engine := connectClient.Engine()
if engine == nil {
return nil, fmt.Errorf("not connected")
}
@@ -377,9 +481,57 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) {
routes := buildSelectRoutes(routesMap, routeSelector.IsSelected, v6ExitMerged)
resolvedDomains := c.recorder.GetResolvedDomainsStates()
// Compute each route's connection status in the core (mirroring the Android
// bridge), so the UI doesn't have to infer it by string-matching the joined
// Network value against peer routes. For a merged exit node the status reflects
// whichever of the v4/v6 prefixes is served by a connected peer; for dynamic
// (DNS) routes the peer route key is the domain pattern (see dynamic.Route.String).
connectedRoutes := c.connectedRouteSet()
for _, r := range routes {
r.Status = routeStatus(r, connectedRoutes)
}
return prepareRouteSelectionDetails(routes, resolvedDomains), nil
}
// connectedRouteSet returns the set of route keys (as strings) currently served by a
// connected peer, gathered across all connected peers' route tables. The keys match
// what the route manager records: a prefix string for static routes (e.g. "0.0.0.0/0")
// and the domain pattern for dynamic routes (e.g. "*.example.com").
func (c *Client) connectedRouteSet() map[string]struct{} {
connected := map[string]struct{}{}
for _, p := range c.recorder.GetFullStatus().Peers {
if p.ConnStatus != peer.StatusConnected {
continue
}
for r := range p.GetRoutes() {
connected[r] = struct{}{}
}
}
return connected
}
// routeStatus reports "Connected" if any of the route's keys is served by a connected
// peer: the primary Network prefix, an extra v6 network of a merged exit node, or the
// domain pattern for a dynamic DNS route. Otherwise "Idle".
func routeStatus(r *selectRoute, connectedRoutes map[string]struct{}) string {
keys := make([]string, 0, 1+len(r.extraNetworks))
if len(r.Domains) > 0 {
keys = append(keys, r.Domains.SafeString())
} else {
keys = append(keys, r.Network.String())
}
for _, extra := range r.extraNetworks {
keys = append(keys, extra.String())
}
for _, k := range keys {
if _, ok := connectedRoutes[k]; ok {
return peer.StatusConnected.String()
}
}
return peer.StatusIdle.String()
}
func buildSelectRoutes(routesMap map[route.NetID][]*route.Route, isSelected func(route.NetID) bool, v6Merged map[route.NetID]struct{}) []*selectRoute {
var routes []*selectRoute
for id, rt := range routesMap {
@@ -462,6 +614,7 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
Network: netStr,
Domains: &domainDetails,
Selected: r.Selected,
Status: r.Status,
})
}
@@ -470,11 +623,12 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom
}
func (c *Client) SelectRoute(id string) error {
if c.connectClient == nil {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
return fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
engine := connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -500,10 +654,11 @@ func (c *Client) SelectRoute(id string) error {
}
func (c *Client) DeselectRoute(id string) error {
if c.connectClient == nil {
_, connectClient := c.stateSnapshot()
if connectClient == nil {
return fmt.Errorf("not connected")
}
engine := c.connectClient.Engine()
engine := connectClient.Engine()
if engine == nil {
return fmt.Errorf("not connected")
}
@@ -527,6 +682,22 @@ func (c *Client) DeselectRoute(id string) error {
return nil
}
// setState stores the running engine state so DebugBundle can reuse the live
// config and ConnectClient. It is cleared on Stop.
func (c *Client) setState(cfg *profilemanager.Config, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.connectClient = cc
}
// stateSnapshot returns the current config and ConnectClient under the lock.
func (c *Client) stateSnapshot() (*profilemanager.Config, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.connectClient
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")

View File

@@ -20,6 +20,7 @@ type RoutesSelectionInfo struct {
Network string
Domains *DomainDetails
Selected bool
Status string
}
type DomainCollection interface {

File diff suppressed because it is too large Load Diff

View File

@@ -85,6 +85,8 @@ service DaemonService {
rpc AddProfile(AddProfileRequest) returns (AddProfileResponse) {}
rpc RenameProfile(RenameProfileRequest) returns (RenameProfileResponse) {}
rpc RemoveProfile(RemoveProfileRequest) returns (RemoveProfileResponse) {}
rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {}
@@ -378,6 +380,9 @@ message RelayState {
string URI = 1;
bool available = 2;
string error = 3;
// transport is the negotiated relay transport (e.g. "ws", "quic"),
// empty for stun/turn probes or when not connected.
string transport = 4;
}
message NSGroupState {
@@ -622,11 +627,18 @@ message GetEventsResponse {
}
message SwitchProfileRequest {
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
optional string profileName = 1;
optional string username = 2;
}
message SwitchProfileResponse {}
message SwitchProfileResponse {
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
string id = 1;
}
message SetConfigRequest {
string username = 1;
@@ -693,17 +705,42 @@ message SetConfigResponse{}
message AddProfileRequest {
string username = 1;
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
string profileName = 2;
}
message AddProfileResponse {}
message AddProfileResponse {
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
string id = 1;
}
message RenameProfileRequest {
string username = 1;
// handle: an exact ID, a unique ID prefix, or a unique display name.
string handle = 2;
// newProfileName is the new human-readable display name for the profile.
string newProfileName = 3;
}
message RenameProfileResponse {
// confirm the old profile name after resolving handle.
string oldProfileName = 1;
}
message RemoveProfileRequest {
string username = 1;
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
string profileName = 2;
}
message RemoveProfileResponse {}
message RemoveProfileResponse {
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
string id = 1;
}
message ListProfilesRequest {
string username = 1;
@@ -716,6 +753,7 @@ message ListProfilesResponse {
message Profile {
string name = 1;
bool is_active = 2;
string id = 3;
}
message GetActiveProfileRequest {}
@@ -723,6 +761,7 @@ message GetActiveProfileRequest {}
message GetActiveProfileResponse {
string profileName = 1;
string username = 2;
string id = 3;
}
message LogoutRequest {

View File

@@ -45,6 +45,7 @@ const (
DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile"
DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig"
DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile"
DaemonService_RenameProfile_FullMethodName = "/daemon.DaemonService/RenameProfile"
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
@@ -112,6 +113,7 @@ type DaemonServiceClient interface {
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error)
RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error)
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
@@ -422,6 +424,16 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ
return out, nil
}
func (c *daemonServiceClient) RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RenameProfileResponse)
err := c.cc.Invoke(ctx, DaemonService_RenameProfile_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoveProfileResponse)
@@ -613,6 +625,7 @@ type DaemonServiceServer interface {
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error)
RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error)
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
@@ -723,6 +736,9 @@ func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigReq
func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented")
}
func (UnimplementedDaemonServiceServer) RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RenameProfile not implemented")
}
func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented")
}
@@ -1237,6 +1253,24 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RenameProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenameProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RenameProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RenameProfile_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RenameProfile(ctx, req.(*RenameProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveProfileRequest)
if err := dec(in); err != nil {
@@ -1567,6 +1601,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "AddProfile",
Handler: _DaemonService_AddProfile_Handler,
},
{
MethodName: "RenameProfile",
Handler: _DaemonService_RenameProfile_Handler,
},
{
MethodName: "RemoveProfile",
Handler: _DaemonService_RemoveProfile_Handler,

View File

@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")

View File

@@ -78,7 +78,7 @@ type Server struct {
// changed by connectWithRetryRuns goroutine exit — for that
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
// derives from clientGiveUpChan close state. Protected by s.mutex.
clientRunning bool
clientRunning bool
clientRunningChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
@@ -375,7 +375,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
return nil, err
}
config, err := setConfigInputFromRequest(msg)
config, err := s.setConfigInputFromRequest(msg)
if err != nil {
return nil, err
}
@@ -398,17 +398,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
// field is its own optional case. Returns the resolved ConfigInput
// and a non-nil error only when the active profile file path cannot
// be determined.
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
var config profilemanager.ConfigInput
profState := profilemanager.ActiveProfileState{
Name: msg.ProfileName,
Username: msg.Username,
}
profPath, err := profState.FilePath()
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return config, fmt.Errorf("failed to get active profile file path: %w", err)
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
return config, err
}
profPath := resolved.Path
if profPath == "" {
profPath = profilemanager.DefaultConfigPath
}
config.ConfigPath = profPath
@@ -535,30 +535,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
if msg.ProfileName != nil {
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
}
var username string
if *msg.ProfileName != "default" {
username = *msg.Username
}
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: *msg.ProfileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
}
}
@@ -568,7 +547,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
s.mutex.Lock()
@@ -806,10 +785,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
return nil, err
}
}
@@ -820,7 +799,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
config, _, err := s.getConfig(activeProf)
if err != nil {
@@ -864,34 +843,60 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
}
}
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
if profileName != "default" && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", profileName)
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
// resolveProfileHandle resolves a wire-level profile handle (display
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
// status errors so handlers can return them directly.
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
p, err := s.profileManager.ResolveProfile(handle, username)
if err == nil {
return p, nil
}
var amb *profilemanager.ErrAmbiguousHandle
if errors.As(err, &amb) {
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
}
if errors.Is(err, profilemanager.ErrProfileNotFound) {
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
}
return nil, fmt.Errorf("resolve profile: %w", err)
}
// switchProfileIfNeeded resolves the user-supplied handle, updates the
// active profile state if it differs from the current one, and returns
// the resolved profile so callers can include its ID in RPC responses.
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", handle)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
}
var username string
if profileName != "default" {
if handle != profilemanager.DefaultProfileName {
username = *userName
}
if profileName != activeProf.Name || username != activeProf.Username {
resolved, err := s.resolveProfileHandle(handle, username)
if err != nil {
return nil, err
}
if resolved.ID != activeProf.ID || username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user %s", profileName, username)
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
ID: resolved.ID,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return fmt.Errorf("failed to set active profile state: %w", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
}
return nil
return resolved, nil
}
// SwitchProfile switches the active profile in the daemon.
@@ -906,9 +911,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
return nil, err
}
}
activeProf, err = s.profileManager.GetActiveProfileState()
@@ -924,7 +929,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
s.config = config
return &proto.SwitchProfileResponse{}, nil
return &proto.SwitchProfileResponse{Id: activeProf.ID.String()}, nil
}
// Down engine work in the daemon.
@@ -1014,22 +1019,27 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
}
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
return nil, err
}
if msg.Username == nil || *msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
}
username := *msg.Username
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
if err != nil {
return nil, err
}
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
}
activeProf, _ := s.profileManager.GetActiveProfileState()
if activeProf != nil && activeProf.Name == *msg.ProfileName {
if activeProf != nil && activeProf.ID == resolved.ID {
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
}
@@ -1091,30 +1101,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
return config, configExisted, nil
}
func (s *Server) canRemoveProfile(profileName string) error {
if profileName == profilemanager.DefaultProfileName {
func (s *Server) canRemoveProfile(id profilemanager.ID) error {
if id == profilemanager.DefaultProfileName {
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName {
return fmt.Errorf("remove active profile: %s", profileName)
if err == nil && activeProf.ID == id {
return fmt.Errorf("remove active profile: %s", id)
}
return nil
}
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfile bool) error {
if s.checkProfilesDisabled() {
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if profileName == "" {
if id == "" {
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
if !allowActiveProfile {
if err := s.canRemoveProfile(profileName); err != nil {
if err := s.canRemoveProfile(id); err != nil {
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
}
}
@@ -1122,25 +1132,20 @@ func (s *Server) validateProfileOperation(profileName string, allowActiveProfile
return nil
}
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
profileState := &profilemanager.ActiveProfileState{
Name: profileName,
Username: username,
}
profilePath, err := profileState.FilePath()
if err != nil {
return fmt.Errorf("get profile path: %w", err)
cfgPath := profile.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
}
config, err := profilemanager.GetConfig(profilePath)
config, err := profilemanager.GetConfig(cfgPath)
if err != nil {
return fmt.Errorf("profile '%s' not found", profileName)
return fmt.Errorf("profile '%s' not found", profile.ID)
}
return s.sendLogoutRequestWithConfig(ctx, config)
@@ -1558,15 +1563,14 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
return nil, ctx.Err()
}
prof := profilemanager.ActiveProfileState{
Name: req.ProfileName,
Username: req.Username,
}
cfgPath, err := prof.FilePath()
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
return nil, err
}
cfgPath := resolved.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
}
cfg, err := profilemanager.GetConfig(cfgPath)
@@ -1671,12 +1675,39 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
return &proto.AddProfileResponse{Id: created.ID.String()}, nil
}
func (s *Server) RenameProfile(ctx context.Context, msg *proto.RenameProfileRequest) (*proto.RenameProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.Handle == "" || msg.Username == "" || msg.NewProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name, username and new profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.Handle, msg.Username)
if err != nil {
return nil, err
}
err = s.profileManager.RenameProfile(resolved.ID, msg.Username, msg.NewProfileName)
if err != nil {
log.Errorf("failed to rename profile: %v", err)
return nil, fmt.Errorf("failed to rename profile: %w", err)
}
return &proto.RenameProfileResponse{OldProfileName: resolved.Name}, nil
}
// RemoveProfile removes a profile from the daemon.
@@ -1684,20 +1715,29 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
return &proto.RemoveProfileResponse{Id: resolved.ID.String()}, nil
}
// ListProfiles lists all profiles in the daemon.
@@ -1720,6 +1760,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Id: profile.ID.String(),
Name: profile.Name,
IsActive: profile.IsActive,
}
@@ -1728,7 +1769,9 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
// GetActiveProfile returns the active profile in the daemon. The ProfileName
// field carries the display name for backwards compatibility with UI clients,
// new callers should prefer Id.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -1739,9 +1782,23 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
// Fallback to legacy name == ID
displayName := activeProfile.ID.String()
if activeProfile.ID != profilemanager.DefaultProfileName {
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
for _, p := range profiles {
if p.ID == activeProfile.ID {
displayName = p.Name
break
}
}
}
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
ProfileName: displayName,
Username: activeProfile.Username,
Id: activeProfile.ID.String(),
}, nil
}

View File

@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test-profile",
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
})
if err != nil {
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
ID: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -62,7 +62,7 @@ func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profN
pm := profilemanager.ServiceManager{}
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
}))
@@ -107,9 +107,9 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)

View File

@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
})
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()

View File

@@ -98,6 +98,7 @@ type RelayStateOutputDetail struct {
URI string `json:"uri" yaml:"uri"`
Available bool `json:"available" yaml:"available"`
Error string `json:"error" yaml:"error"`
Transport string `json:"transport,omitempty" yaml:"transport,omitempty"`
}
type RelayStateOutput struct {
@@ -219,7 +220,8 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
RelayStateOutputDetail{
URI: relay.URI,
Available: available,
Error: relay.GetError(),
Error: relayErrorString(relay.GetError()),
Transport: relay.GetTransport(),
},
)
@@ -235,6 +237,12 @@ func mapRelays(relays []*proto.RelayState) RelayStateOutput {
}
}
// relayErrorString flattens a newline-joined aggregated relay error onto a
// single line for status output.
func relayErrorString(s string) string {
return strings.ReplaceAll(s, "\n", "; ")
}
func mapNSGroups(servers []*proto.NSGroupState) []NsServerGroupStateOutput {
mappedNSGroups := make([]NsServerGroupStateOutput, 0, len(servers))
for _, pbNsGroupServer := range servers {
@@ -441,6 +449,8 @@ func (o *OutputOverview) GeneralSummary(showURL bool, showRelays bool, showNameS
available = "Unavailable"
reason = fmt.Sprintf(", reason: %s", relay.Error)
}
} else if relay.Transport != "" {
available = fmt.Sprintf("%s via %s", available, relay.Transport)
}
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)

View File

@@ -647,3 +647,13 @@ func TestTimeAgo(t *testing.T) {
})
}
}
func TestMapRelaysTransport(t *testing.T) {
out := mapRelays([]*proto.RelayState{
{URI: "rels://relay.example:443", Available: true, Transport: "quic"},
{URI: "rels://relay2.example:443", Available: true, Transport: "ws"},
})
require.Len(t, out.Details, 2)
assert.Equal(t, "quic", out.Details[0].Transport)
assert.Equal(t, "ws", out.Details[1].Transport)
}

View File

@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
}
req := &proto.SetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
}
@@ -818,13 +818,15 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err)
}
handle := activeProf.ID.String()
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
ProfileName: &handle,
Username: &currUser.Username,
}
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -1367,7 +1369,7 @@ func (s *serviceClient) getSrvConfig() {
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
})
if err != nil {
@@ -1613,7 +1615,7 @@ func (s *serviceClient) loadSettings() {
}
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
})
if err != nil {
@@ -1813,7 +1815,7 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,

View File

@@ -66,7 +66,7 @@ func (s *serviceClient) showProfilesUI() {
} else {
indicator.SetText("")
}
nameLabel.SetText(profile.Name)
nameLabel.SetText(formatProfileLabel(profile, profiles))
// Configure Select/Active button
selectBtn.SetText(func() string {
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
// switch
err = s.switchProfile(profile.Name)
err = s.switchProfile(profile.ID)
if err != nil {
log.Errorf("failed to switch profile: %v", err)
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
logoutBtn.Show()
logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh)
s.handleProfileLogout(profile, refresh)
}
// Remove profile
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
err = s.removeProfile(profile.Name)
err = s.removeProfile(profile.ID)
if err != nil {
log.Errorf("failed to remove profile: %v", err)
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
return nil
}
func (s *serviceClient) switchProfile(profileName string) error {
func (s *serviceClient) switchProfile(handle string) error {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return fmt.Errorf(getClientFMT, err)
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(profileName string) error {
return fmt.Errorf("get current user: %w", err)
}
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
}); err != nil {
})
if err != nil {
return fmt.Errorf("switch profile failed: %w", err)
}
err = s.profileManager.SwitchProfile(profileName)
if err != nil {
if err := s.profileManager.SwitchProfile(profilemanager.ID(resp.Id)); err != nil {
return fmt.Errorf("switch profile: %w", err)
}
@@ -299,10 +299,27 @@ func (s *serviceClient) removeProfile(profileName string) error {
}
type Profile struct {
ID string
Name string
IsActive bool
}
// formatProfileLabel returns the display label for a profile. Profiles can
// share the same Name, so when more than one profile in profiles carries this
// Name, a short form of the ID is appended to disambiguate the entries.
func formatProfileLabel(profile Profile, profiles []Profile) string {
count := 0
for _, p := range profiles {
if p.Name == profile.Name {
count++
}
}
if count <= 1 {
return profile.Name
}
return fmt.Sprintf("%s (%s)", profile.Name, profilemanager.ID(profile.ID).ShortID())
}
func (s *serviceClient) getProfiles() ([]Profile, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -324,6 +341,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -332,10 +350,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
return profiles, nil
}
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
dialog.ShowConfirm(
"Deregister",
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
func(confirm bool) {
if !confirm {
return
@@ -356,8 +374,10 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
}
username := currUser.Username
// ProfileName is treated as a handle; send the ID so the
// daemon resolves to exactly this profile.
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profileName,
ProfileName: &profile.ID,
Username: &username,
})
if err != nil {
@@ -368,7 +388,7 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
dialog.ShowInformation(
"Deregistered",
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
s.wProfiles,
)
@@ -461,6 +481,7 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -501,7 +522,7 @@ func (p *profileMenu) refresh() {
}
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
activeProfState, err := p.profileManager.GetProfileState(profilemanager.ID(activeProf.Id))
if err != nil {
log.Warnf("failed to get active profile state: %v", err)
p.emailMenuItem.Hide()
@@ -512,7 +533,7 @@ func (p *profileMenu) refresh() {
}
for _, profile := range profiles {
item := p.profileMenuItem.AddSubMenuItem(profile.Name, "")
item := p.profileMenuItem.AddSubMenuItem(formatProfileLabel(profile, profiles), "")
if profile.IsActive {
item.Check()
}
@@ -541,8 +562,8 @@ func (p *profileMenu) refresh() {
return
}
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.Name,
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.ID,
Username: &currUser.Username,
})
if err != nil {
@@ -552,7 +573,7 @@ func (p *profileMenu) refresh() {
return
}
err = p.profileManager.SwitchProfile(profile.Name)
err = p.profileManager.SwitchProfile(profilemanager.ID(switchResp.Id))
if err != nil {
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
return
@@ -727,7 +748,10 @@ func (p *profileMenu) updateMenu() {
}
sort.Slice(profiles, func(i, j int) bool {
return profiles[i].Name < profiles[j].Name
if profiles[i].Name != profiles[j].Name {
return profiles[i].Name < profiles[j].Name
}
return profiles[i].ID < profiles[j].ID
})
p.mu.Lock()

View File

@@ -21,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
nbwebsocket "github.com/netbirdio/netbird/client/wasm/internal/websocket"
"github.com/netbirdio/netbird/util"
)
@@ -30,6 +31,7 @@ const (
pingTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
dialWebSocketTimeout = 30 * time.Second
icmpEchoRequest = 8
icmpCodeEcho = 0
@@ -677,6 +679,7 @@ func createClientObject(client *netbird.Client) js.Value {
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)
obj["dialWebSocket"] = createDialWebSocketMethod(client)
obj["status"] = createStatusMethod(client)
obj["statusSummary"] = createStatusSummaryMethod(client)
obj["statusDetail"] = createStatusDetailMethod(client)
@@ -691,6 +694,74 @@ func createClientObject(client *netbird.Client) js.Value {
return js.ValueOf(obj)
}
func createDialWebSocketMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
url, protocols, timeout, errVal := parseDialWebSocketArgs(args)
if !errVal.IsUndefined() {
return errVal
}
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := nbwebsocket.Dial(ctx, client, url, protocols)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("dial websocket: %v", err)))
return
}
resolve.Invoke(nbwebsocket.NewJSInterface(conn))
})
})
}
func parseDialWebSocketArgs(args []js.Value) (url string, protocols []string, timeout time.Duration, errVal js.Value) {
if len(args) < 1 || args[0].Type() != js.TypeString {
return "", nil, 0, js.ValueOf("error: dialWebSocket requires a URL string argument")
}
url = args[0].String()
if len(args) >= 2 && !args[1].IsNull() && !args[1].IsUndefined() {
arr, err := jsStringArray(args[1])
if err != nil {
return "", nil, 0, js.ValueOf(fmt.Sprintf("error: protocols: %v", err))
}
protocols = arr
}
timeout = dialWebSocketTimeout
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
if args[2].Type() != js.TypeNumber {
return "", nil, 0, js.ValueOf("error: timeoutMs must be a number")
}
timeoutMs := args[2].Int()
if timeoutMs <= 0 {
return "", nil, 0, js.ValueOf("error: timeout must be positive")
}
timeout = time.Duration(timeoutMs) * time.Millisecond
}
return url, protocols, timeout, js.Undefined()
}
// jsStringArray converts a JS array of strings to a Go []string.
func jsStringArray(v js.Value) ([]string, error) {
if !v.InstanceOf(js.Global().Get("Array")) {
return nil, fmt.Errorf("expected array")
}
n := v.Length()
out := make([]string, n)
for i := 0; i < n; i++ {
el := v.Index(i)
if el.Type() != js.TypeString {
return nil, fmt.Errorf("element %d is not a string", i)
}
out[i] = el.String()
}
return out, nil
}
// netBirdClientConstructor acts as a JavaScript constructor function
func netBirdClientConstructor(_ js.Value, args []js.Value) any {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {

View File

@@ -0,0 +1,304 @@
//go:build js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"sync"
"syscall/js"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
netbird "github.com/netbirdio/netbird/client/embed"
log "github.com/sirupsen/logrus"
)
type closeError struct {
code uint16
reason string
}
func (e *closeError) Error() string {
return fmt.Sprintf("websocket closed: %d %s", e.code, e.reason)
}
// bufferedConn fronts a net.Conn with a reader that serves any bytes buffered
// during the WebSocket handshake before falling through to the raw conn.
type bufferedConn struct {
net.Conn
r io.Reader
}
func (c *bufferedConn) Read(p []byte) (int, error) { return c.r.Read(p) }
// Conn wraps a WebSocket connection over a NetBird TCP connection.
type Conn struct {
conn net.Conn
mu sync.Mutex
closed chan struct{}
closeOnce sync.Once
closeErr error
}
// Dial establishes a WebSocket connection to the given URL through the NetBird network.
// Optional protocols are sent via the Sec-WebSocket-Protocol header.
func Dial(ctx context.Context, client *netbird.Client, rawURL string, protocols []string) (*Conn, error) {
d := ws.Dialer{
NetDial: client.Dial,
Protocols: protocols,
}
conn, br, _, err := d.Dial(ctx, rawURL)
if err != nil {
return nil, fmt.Errorf("websocket dial: %w", err)
}
// br is non-nil when the server pushed frames alongside the handshake
// response; those bytes live in the bufio.Reader and must be drained
// before reading from conn, otherwise we'd skip the first frames.
if br != nil {
if br.Buffered() > 0 {
conn = &bufferedConn{Conn: conn, r: io.MultiReader(br, conn)}
} else {
ws.PutReader(br)
}
}
return &Conn{
conn: conn,
closed: make(chan struct{}),
}, nil
}
// ReadMessage reads the next WebSocket message, handling control frames automatically.
func (c *Conn) ReadMessage() (ws.OpCode, []byte, error) {
for {
msgs, err := wsutil.ReadServerMessage(c.conn, nil)
if err != nil {
return 0, nil, err
}
for _, msg := range msgs {
if msg.OpCode.IsControl() {
if err := c.handleControl(msg); err != nil {
return 0, nil, err
}
continue
}
return msg.OpCode, msg.Payload, nil
}
}
}
func (c *Conn) handleControl(msg wsutil.Message) error {
switch msg.OpCode {
case ws.OpPing:
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpPong, msg.Payload)
case ws.OpClose:
code, reason := parseClosePayload(msg.Payload)
return &closeError{code: code, reason: reason}
default:
return nil
}
}
// WriteText sends a text WebSocket message.
func (c *Conn) WriteText(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpText, data)
}
// WriteBinary sends a binary WebSocket message.
func (c *Conn) WriteBinary(data []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
return wsutil.WriteClientMessage(c.conn, ws.OpBinary, data)
}
// Close sends a close frame with StatusNormalClosure and closes the underlying connection.
func (c *Conn) Close() error {
return c.closeWith(ws.StatusNormalClosure, "")
}
// closeWith sends a close frame with the given code/reason and closes the underlying connection.
// Used to echo the server's code when responding to a server-initiated close per RFC 6455 §5.5.1.
func (c *Conn) closeWith(code ws.StatusCode, reason string) error {
var first bool
c.closeOnce.Do(func() {
first = true
close(c.closed)
c.mu.Lock()
_ = wsutil.WriteClientMessage(c.conn, ws.OpClose, ws.NewCloseFrameBody(code, reason))
c.mu.Unlock()
c.closeErr = c.conn.Close()
})
if !first {
return net.ErrClosed
}
return c.closeErr
}
// NewJSInterface creates a JavaScript object wrapping the WebSocket connection.
// It exposes: send(string|Uint8Array), close(), and callback properties
// onmessage, onclose, onerror.
//
// Callback properties may be set from the JS thread while the read loop
// goroutine reads them. In WASM this is safe because Go and JS share a
// single thread, but the design would need synchronization on
// multi-threaded runtimes.
func NewJSInterface(conn *Conn) js.Value {
obj := js.Global().Get("Object").Call("create", js.Null())
sendFunc := js.FuncOf(func(_ js.Value, args []js.Value) any {
if len(args) < 1 {
log.Errorf("websocket send requires a data argument")
return js.ValueOf(false)
}
data := args[0]
switch data.Type() {
case js.TypeString:
if err := conn.WriteText([]byte(data.String())); err != nil {
log.Errorf("failed to send websocket text: %v", err)
return js.ValueOf(false)
}
default:
buf, err := jsToBytes(data)
if err != nil {
log.Errorf("failed to convert js value to bytes: %v", err)
return js.ValueOf(false)
}
if err := conn.WriteBinary(buf); err != nil {
log.Errorf("failed to send websocket binary: %v", err)
return js.ValueOf(false)
}
}
return js.ValueOf(true)
})
obj.Set("send", sendFunc)
closeFunc := js.FuncOf(func(_ js.Value, _ []js.Value) any {
if err := conn.Close(); err != nil {
log.Debugf("failed to close websocket: %v", err)
}
return js.Undefined()
})
obj.Set("close", closeFunc)
go func() {
defer func() {
if err := conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
log.Debugf("close websocket on readLoop exit: %v", err)
}
}()
readLoop(conn, obj)
// Undefining before Release turns post-close JS calls into TypeError
// instead of a silent "call to released function".
obj.Set("send", js.Undefined())
obj.Set("close", js.Undefined())
sendFunc.Release()
closeFunc.Release()
}()
return obj
}
func jsToBytes(data js.Value) ([]byte, error) {
var uint8Array js.Value
switch {
case data.InstanceOf(js.Global().Get("Uint8Array")):
uint8Array = data
case data.InstanceOf(js.Global().Get("ArrayBuffer")):
uint8Array = js.Global().Get("Uint8Array").New(data)
default:
return nil, fmt.Errorf("send: unsupported data type, use string, Uint8Array, or ArrayBuffer")
}
buf := make([]byte, uint8Array.Get("length").Int())
js.CopyBytesToGo(buf, uint8Array)
return buf, nil
}
func readLoop(conn *Conn, obj js.Value) {
var ce *closeError
defer func() { invokeOnClose(obj, ce) }()
for {
select {
case <-conn.closed:
return
default:
}
op, payload, err := conn.ReadMessage()
if err != nil {
ce = handleReadError(conn, obj, err)
return
}
dispatchMessage(obj, op, payload)
}
}
func handleReadError(conn *Conn, obj js.Value, err error) *closeError {
var ce *closeError
if errors.As(err, &ce) {
if cerr := conn.closeWith(ws.StatusCode(ce.code), ce.reason); cerr != nil {
log.Debugf("failed to close websocket after server close frame: %v", cerr)
}
return ce
}
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
return nil
}
if onerror := obj.Get("onerror"); onerror.Truthy() {
onerror.Invoke(js.ValueOf(err.Error()))
}
return nil
}
func invokeOnClose(obj js.Value, ce *closeError) {
onclose := obj.Get("onclose")
if !onclose.Truthy() {
return
}
if ce != nil {
onclose.Invoke(js.ValueOf(int(ce.code)), js.ValueOf(ce.reason))
return
}
onclose.Invoke()
}
func dispatchMessage(obj js.Value, op ws.OpCode, payload []byte) {
onmessage := obj.Get("onmessage")
if !onmessage.Truthy() {
return
}
switch op {
case ws.OpText:
onmessage.Invoke(js.ValueOf(string(payload)))
case ws.OpBinary:
uint8Array := js.Global().Get("Uint8Array").New(len(payload))
js.CopyBytesToJS(uint8Array, payload)
onmessage.Invoke(uint8Array)
}
}
func parseClosePayload(payload []byte) (uint16, string) {
if len(payload) < 2 {
return 1005, "" // RFC 6455: No Status Rcvd
}
code := binary.BigEndian.Uint16(payload[:2])
return code, string(payload[2:])
}

View File

@@ -2,4 +2,5 @@ FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-server" ]
CMD ["--config", "/etc/netbird/config.yaml"]
COPY netbird-server /go/bin/netbird-server
ARG TARGETPLATFORM
COPY ${TARGETPLATFORM}/netbird-server /go/bin/netbird-server

3
go.mod
View File

@@ -56,6 +56,7 @@ require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gliderlabs/ssh v0.3.8
github.com/go-jose/go-jose/v4 v4.1.4
github.com/gobwas/ws v1.4.0
github.com/goccy/go-yaml v1.18.0
github.com/godbus/dbus/v5 v5.1.0
github.com/golang-jwt/jwt/v5 v5.3.1
@@ -215,6 +216,8 @@ require (
github.com/go-viper/mapstructure/v2 v2.5.0 // indirect
github.com/go-webauthn/webauthn v0.16.4 // indirect
github.com/go-webauthn/x v0.2.3 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/google/btree v1.1.3 // indirect

7
go.sum
View File

@@ -249,6 +249,12 @@ github.com/go-webauthn/webauthn v0.16.4 h1:R9jqR/cYZa7hRquFF7Za/8qoH/K/TIs1/Q/4C
github.com/go-webauthn/webauthn v0.16.4/go.mod h1:SU2ljAgToTV/YLPI0C05QS4qn+e04WpB5g1RMfcZfS4=
github.com/go-webauthn/x v0.2.3 h1:8oArS+Rc1SWFLXhE17KZNx258Z4kUSyaDgsSncCO5RA=
github.com/go-webauthn/x v0.2.3/go.mod h1:tM04GF3V6VYq79AZMl7vbj4q6pz9r7L2criWRzbWhPk=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.4.0 h1:CTaoG1tojrh4ucGPcoJFiAQUAsEWekEWvLy7GsVNqGs=
github.com/gobwas/ws v1.4.0/go.mod h1:G3gNqMNtPppf5XUz7O4shetPpcZ1VJ7zt18dlUeakrc=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk=
@@ -845,6 +851,7 @@ golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -41,6 +41,8 @@ type Config struct {
GRPCAddr string
}
const localConnectorID = "local"
// Provider wraps a Dex server
type Provider struct {
config *Config
@@ -544,7 +546,7 @@ func (p *Provider) CreateUser(ctx context.Context, email, username, password str
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
// This matches the format Dex uses in JWT tokens
encodedID := EncodeDexUserID(userID, "local")
encodedID := EncodeDexUserID(userID, localConnectorID)
return encodedID, nil
}
@@ -619,6 +621,13 @@ func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
return userID, connectorID, nil
}
// IsLocalUserID reports whether encodedID is a Dex subject for the built-in
// local password connector.
func IsLocalUserID(encodedID string) bool {
_, connectorID, err := DecodeDexUserID(encodedID)
return err == nil && connectorID == localConnectorID
}
// GetUser returns a user by email
func (p *Provider) GetUser(ctx context.Context, email string) (storage.Password, error) {
return p.storage.GetPassword(ctx, email)

View File

@@ -115,6 +115,26 @@ func TestDecodeDexUserID(t *testing.T) {
}
}
func TestIsLocalUserID(t *testing.T) {
tests := []struct {
name string
encodedID string
want bool
}{
{name: "local connector", encodedID: EncodeDexUserID("7aad8c05-3287-473f-b42a-365504bf25e7", "local"), want: true},
{name: "federated connector", encodedID: EncodeDexUserID("entra-user", "entra"), want: false},
{name: "non-dex external IdP id", encodedID: "google-oauth2|1234567890", want: false},
{name: "invalid base64", encodedID: "not-valid-base64!!!", want: false},
{name: "empty", encodedID: "", want: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, IsLocalUserID(tt.encodedID))
})
}
}
func TestEncodeDexUserID(t *testing.T) {
userID := "7aad8c05-3287-473f-b42a-365504bf25e7"
connectorID := "local"

View File

@@ -2,4 +2,5 @@ FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management"]
CMD ["--log-file", "console"]
COPY netbird-mgmt /go/bin/netbird-mgmt
ARG TARGETPLATFORM
COPY ${TARGETPLATFORM}/netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -1,5 +0,0 @@
FROM ubuntu:24.04
RUN apt update && apt install -y ca-certificates && rm -fr /var/cache/apt
ENTRYPOINT [ "/go/bin/netbird-mgmt","management","--log-level","debug"]
CMD ["--log-file", "console"]
COPY netbird-mgmt /go/bin/netbird-mgmt

View File

@@ -45,7 +45,7 @@ type Controller struct {
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
affectedPeerUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
@@ -64,6 +64,13 @@ type bufferUpdate struct {
update atomic.Bool
}
type bufferAffectedUpdate struct {
sendMu sync.Mutex
dataMu sync.Mutex
next *time.Timer
peerIDs map[string]struct{}
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
@@ -201,7 +208,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
@@ -226,44 +233,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
@@ -273,6 +242,143 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
return c.sendUpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
return nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
if len(peersToUpdate) == 0 {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
return nil
}
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range peersToUpdate {
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
return true
}
}
return false
}
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
var result []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
result = append(result, peer)
}
}
return result
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
@@ -381,66 +487,164 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
if len(peerIDs) == 0 {
return nil
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
})
b := bufUpd.(*bufferAffectedUpdate)
b.addPeerIDs(peerIDs)
if !b.sendMu.TryLock() {
// Another goroutine is already sending; it will pick up our IDs on its next drain.
return nil
}
b.stopTimer()
// The send and the debounced timer outlive the calling request, so detach from
// its context to avoid sending with a cancelled context once the handler returns.
bgCtx := context.WithoutCancel(ctx)
collected := b.drainPeerIDs()
go func() {
defer b.sendMu.Unlock()
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, collected)
// Check if more peer IDs accumulated while we were sending.
if !b.hasPending() {
return
}
// Schedule a debounced flush for the newly accumulated IDs.
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
ids := b.drainPeerIDs()
if len(ids) > 0 {
_ = c.sendUpdateForAffectedPeers(bgCtx, accountID, ids)
}
})
}()
return nil
}
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
b.dataMu.Lock()
for _, id := range ids {
b.peerIDs[id] = struct{}{}
}
b.dataMu.Unlock()
}
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if len(b.peerIDs) == 0 {
return nil
}
ids := make([]string, 0, len(b.peerIDs))
for id := range b.peerIDs {
ids = append(ids, id)
}
b.peerIDs = make(map[string]struct{})
return ids
}
func (b *bufferAffectedUpdate) hasPending() bool {
b.dataMu.Lock()
defer b.dataMu.Unlock()
return len(b.peerIDs) > 0
}
func (b *bufferAffectedUpdate) stopTimer() {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next != nil {
b.next.Stop()
}
}
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next == nil {
b.next = time.AfterFunc(d, f)
return
}
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
emptyMap := &types.NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, nil, 0, nil
return emptyMap, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
account.InjectProxyPolicies(ctx)
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peer.ID)
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers)
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return nil, nil, nil, 0, err
return nil, nil, 0, err
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peerID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
proxyNetworkMap, ok := proxyNetworkMaps[peerID]
if ok {
networkMap.Merge(proxyNetworkMap)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
return peer, networkMap, postureChecks, dnsFwdPort, nil
return networkMap, postureChecks, dnsFwdPort, nil
}
// GetDNSDomain returns the configured dnsDomain
@@ -578,21 +782,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
return nil
}
return nil
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
@@ -625,7 +832,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -19,17 +19,19 @@ const (
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)

View File

@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
@@ -113,21 +127,20 @@ func (mr *MockControllerMockRecorder) GetNetworkMap(ctx, peerID any) *gomock.Cal
}
// GetValidatedPeerWithMap mocks base method.
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *peer.Peer) (*peer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
func (m *MockController) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peerID string) (*types.NetworkMap, []*posture.Checks, int64, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, p)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(int64)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
ret := m.ctrl.Call(m, "GetValidatedPeerWithMap", ctx, isRequiresApproval, accountID, peerID)
ret0, _ := ret[0].(*types.NetworkMap)
ret1, _ := ret[1].([]*posture.Checks)
ret2, _ := ret[2].(int64)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
}
// GetValidatedPeerWithMap indicates an expected call of GetValidatedPeerWithMap.
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, p any) *gomock.Call {
func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, peerID)
}
// OnPeerConnected mocks base method.
@@ -158,45 +171,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
}
// StartWarmup mocks base method.
@@ -250,3 +263,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}

View File

@@ -242,7 +242,7 @@ func (m *managerImpl) CreateProxyPeer(ctx context.Context, accountID string, pee
},
}
_, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
_, _, _, _, err = m.accountManager.AddPeer(ctx, accountID, "", "", peer, true)
if err != nil {
return fmt.Errorf("failed to create proxy peer: %w", err)
}

View File

@@ -918,6 +918,10 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
}
for _, svc := range services {
if err = transaction.DeleteServiceTargets(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, svc.ID); err != nil {
return fmt.Errorf("failed to delete service: %w", err)
}
@@ -1270,6 +1274,10 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
return status.Errorf(status.PermissionDenied, "cannot delete service exposed by another peer")
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}
@@ -1319,6 +1327,10 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
return nil
}
if err = transaction.DeleteServiceTargets(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service targets: %w", err)
}
if err = transaction.DeleteService(ctx, accountID, serviceID); err != nil {
return fmt.Errorf("delete service: %w", err)
}

View File

@@ -458,6 +458,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -560,6 +563,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -604,6 +610,9 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
txMock.EXPECT().
GetServiceByID(ctx, store.LockingStrengthUpdate, accountID, serviceID).
Return(newEphemeralService(), nil)
txMock.EXPECT().
DeleteServiceTargets(ctx, accountID, serviceID).
Return(nil)
txMock.EXPECT().
DeleteService(ctx, accountID, serviceID).
Return(nil)
@@ -1192,6 +1201,67 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
assert.Len(t, targets, 0, "All targets should be deleted when service is deleted")
}
func TestDeleteExpiredPeerService_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before reaping")
expireEphemeralService(t, testStore, testAccountID, resp.Domain)
err = mgr.deleteExpiredPeerService(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "expired peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when an expired peer-exposed service is reaped")
}
func TestDeleteServiceFromPeer_DeletesTargets(t *testing.T) {
ctx := context.Background()
mgr, testStore := setupIntegrationTest(t)
resp, err := mgr.CreateServiceFromPeer(ctx, testAccountID, testPeerID, &rpservice.ExposeServiceRequest{
Port: 8080,
Mode: "http",
})
require.NoError(t, err)
svcID := resolveServiceIDByDomain(t, testStore, resp.Domain)
targets, err := testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
require.Len(t, targets, 1, "ephemeral peer-exposed service should have exactly one persisted target before stopping")
err = mgr.StopServiceFromPeer(ctx, testAccountID, testPeerID, svcID)
require.NoError(t, err)
_, err = testStore.GetServiceByDomain(ctx, resp.Domain)
require.Error(t, err, "stopped peer-exposed service should be deleted")
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, status.NotFound, s.Type())
targets, err = testStore.GetTargetsByServiceID(ctx, store.LockingStrengthNone, testAccountID, svcID)
require.NoError(t, err)
assert.Len(t, targets, 0, "orphaned target rows must be deleted when a peer stops its exposed service")
}
func TestValidateProtocolChange(t *testing.T) {
tests := []struct {
name string

View File

@@ -778,7 +778,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
peer, netMap, postureChecks, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
peer, network, postureChecks, enableSSH, err := s.accountManager.LoginPeer(ctx, types.PeerLogin{
WireGuardPubKey: peerKey.String(),
SSHKey: string(sshKey),
Meta: peerMeta,
@@ -792,7 +792,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, mapError(ctx, err)
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
loginResp, err := s.prepareLoginResponse(ctx, peer, network, postureChecks, enableSSH)
if err != nil {
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
return nil, status.Errorf(codes.Internal, "failed logging in peer")
@@ -895,7 +895,7 @@ func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMess
}, nil
}
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, network *types.Network, postureChecks []*posture.Checks, enableSSH bool) (*proto.LoginResponse, error) {
var relayToken *Token
var err error
if s.config.Relay != nil && len(s.config.Relay.Addresses) > 0 {
@@ -914,7 +914,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
// if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, netMap.EnableSSH),
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
Checks: toProtocolChecks(ctx, postureChecks),
}

View File

@@ -28,6 +28,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/account"
@@ -1588,7 +1589,10 @@ func (am *DefaultAccountManager) updateUserAuthWithSingleMode(ctx context.Contex
// and propagates changes to peers if group propagation is enabled.
// requires userAuth to have been ValidateAndParseToken and EnsureUserAccessByJWTGroups by the AuthManager
func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error {
if userAuth.IsChild || userAuth.IsPAT {
// Child accounts and PAT-authenticated requests do not sync JWT groups.
// Embedded-Dex local users also skip sync because local password authentication
// does not provide external IdP group claims.
if userAuth.IsChild || userAuth.IsPAT || dex.IsLocalUserID(userAuth.UserId) {
return nil
}
@@ -1890,7 +1894,7 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -2573,7 +2577,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -2664,7 +2670,9 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
}
if updateNetworkMap {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return fmt.Errorf("notify network map controller: %w", err)
}
}

View File

@@ -13,6 +13,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -61,7 +62,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -69,7 +70,7 @@ type Manager interface {
UpdatePeerIPv6(ctx context.Context, accountID, userID, peerID string, newIPv6 netip.Addr) error
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
AddPeer(ctx context.Context, accountID, setupKey, userID string, p *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error)
CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error)
DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error
GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error)
@@ -108,8 +109,8 @@ type Manager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
@@ -128,6 +129,7 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change)
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -15,6 +15,7 @@ import (
dns "github.com/netbirdio/netbird/dns"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
activity "github.com/netbirdio/netbird/management/server/activity"
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
idp "github.com/netbirdio/netbird/management/server/idp"
peer "github.com/netbirdio/netbird/management/server/peer"
posture "github.com/netbirdio/netbird/management/server/posture"
@@ -79,14 +80,15 @@ func (mr *MockManagerMockRecorder) AccountExists(ctx, accountID interface{}) *go
}
// AddPeer mocks base method.
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (m *MockManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, p *peer.Peer, temporary bool) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPeer", ctx, accountID, setupKey, userID, p, temporary)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret1, _ := ret[1].(*types.Network)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// AddPeer indicates an expected call of AddPeer.
@@ -1288,14 +1290,15 @@ func (mr *MockManagerMockRecorder) ListUsers(ctx, accountID interface{}) *gomock
}
// LoginPeer mocks base method.
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.NetworkMap, []*posture.Checks, error) {
func (m *MockManager) LoginPeer(ctx context.Context, login types.PeerLogin) (*peer.Peer, *types.Network, []*posture.Checks, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LoginPeer", ctx, login)
ret0, _ := ret[0].(*peer.Peer)
ret1, _ := ret[1].(*types.NetworkMap)
ret1, _ := ret[1].(*types.Network)
ret2, _ := ret[2].([]*posture.Checks)
ret3, _ := ret[3].(error)
return ret0, ret1, ret2, ret3
ret3, _ := ret[3].(bool)
ret4, _ := ret[4].(error)
return ret0, ret1, ret2, ret3, ret4
}
// LoginPeer indicates an expected call of LoginPeer.
@@ -1320,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
}
// MarkPeerDisconnected mocks base method.
@@ -1637,6 +1640,18 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// ExpandAndUpdateAffected mocks base method.
func (m *MockManager) ExpandAndUpdateAffected(ctx context.Context, accountID string, snap *affectedpeers.Snapshot, change affectedpeers.Change) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ExpandAndUpdateAffected", ctx, accountID, snap, change)
}
// ExpandAndUpdateAffected indicates an expected call of ExpandAndUpdateAffected.
func (mr *MockManagerMockRecorder) ExpandAndUpdateAffected(ctx, accountID, snap, change interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExpandAndUpdateAffected", reflect.TypeOf((*MockManager)(nil).ExpandAndUpdateAffected), ctx, accountID, snap, change)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
@@ -83,7 +84,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager nbAccount.Manager, account
setupKey = key.Key
}
_, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
_, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey, userID, peer, false)
if err != nil {
t.Error("expected to add new peer successfully after creating new account, but failed", err)
}
@@ -723,6 +724,28 @@ func TestDefaultAccountManager_SyncUserJWTGroups(t *testing.T) {
require.Equal(t, g2.Name, "group2", "group2 name should match")
require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match")
})
t.Run("local embedded-Dex user is skipped", func(t *testing.T) {
initAccount.Settings.JWTGroupsEnabled = true
initAccount.Settings.JWTGroupsClaimName = "idp-groups"
err := manager.Store.SaveAccount(context.Background(), initAccount)
require.NoError(t, err, "save account failed")
localClaims := auth.UserAuth{
AccountId: accountID,
Domain: domain,
UserId: dex.EncodeDexUserID("local-owner", "local"),
Groups: []string{"group3", "group4"},
}
err = manager.SyncUserJWTGroups(context.Background(), localClaims)
require.NoError(t, err, "sync should be a no-op for local users")
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "get account failed")
for _, g := range account.Groups {
require.NotEqual(t, "group3", g.Name, "local user JWT groups must not be synced")
require.NotEqual(t, "group4", g.Name, "local user JWT groups must not be synced")
}
})
}
func TestAccountManager_PrivateAccount(t *testing.T) {
@@ -1069,7 +1092,7 @@ func TestAccountManager_AddPeer(t *testing.T) {
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1133,7 +1156,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
expectedPeerKey := key.PublicKey().String()
expectedUserID := userID
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
}, false)
@@ -1481,7 +1504,7 @@ func TestAccountManager_DeletePeer(t *testing.T) {
peerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: peerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: peerKey},
}, false)
@@ -1803,7 +1826,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1813,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1859,7 +1882,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -1884,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1904,7 +1927,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
}, false)
@@ -1912,7 +1935,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1933,7 +1956,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1957,7 +1980,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1967,7 +1990,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1994,7 +2017,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
}, false)
@@ -2029,7 +2052,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
}()
}
@@ -2057,7 +2080,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
LoginExpirationEnabled: true,
@@ -2070,7 +2093,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -3253,7 +3276,7 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
peer, _, _, _, err := manager.AddPeer(context.Background(), "", setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
Status: &nbpeer.PeerStatus{
@@ -3282,6 +3305,19 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
for {
select {
case _, ok := <-ch:
if !ok {
return
}
case <-time.After(200 * time.Millisecond):
return
}
}
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {
@@ -3408,7 +3444,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: account.Peers["peer-1"].Key,
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3477,7 +3513,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
_, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
_, _, _, _, err := manager.LoginPeer(context.Background(), types.PeerLogin{
WireGuardPubKey: "some-new-key" + strconv.Itoa(i),
SSHKey: "someKey",
Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)},
@@ -3872,13 +3908,13 @@ func TestDefaultAccountManager_UpdatePeerIP(t *testing.T) {
key2, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peer1, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer1, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key1.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-1"},
}, false)
require.NoError(t, err, "unable to add peer1")
peer2, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
peer2, _, _, _, err := manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: key2.PublicKey().String(),
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer-2"},
}, false)

View File

@@ -0,0 +1,117 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
// dependency crossed with the change-type that can alter it, asserting the
// resolver folds in exactly the peers whose map changes. A new dependency that
// the resolver fails to walk should fail one of these rows; a new change-type
// without a row is a coverage gap to add here.
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
type row struct {
name string
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
}
rows := []row{
{
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "policy-change/explicit-policy refreshes source+routing",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
return affectedpeers.Change{Policies: []*types.Policy{policy}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "policy-destinationresource/explicit-policy bridges to routing peer",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
return affectedpeers.Change{Policies: []*types.Policy{policy}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "resource-change refreshes source+routing on its network",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{Resources: []*resourceTypes.NetworkResource{
{ID: s.resourceID, NetworkID: s.networkID, GroupIDs: []string{s.resourceGroupID}},
}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "network-change refreshes source+routing on that network",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "posture-check-change refreshes source+routing of gated policy",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
Name: "cov-min-version",
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
}, true)
require.NoError(t, err)
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.SourcePostureChecks = []string{check.ID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
}
for _, r := range rows {
t.Run(r.name, func(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
})
}
}

View File

@@ -0,0 +1,143 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// An update spans an old and a new state. The affected set must be the UNION of
// peers reachable before and after the change; resolving only against the final
// state drops peers that were reachable but no longer are. These tests pin the
// two paths where the old state is reachable only by the changed object's
// previous references: detaching a resource group, and re-pointing a router peer.
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
// a resource is reachable by a source group via two destination resource groups;
// detaching one of them must still refresh that group's policy source peers, even
// though the post-update resource no longer maps to it.
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A second resource group + a second source group/peer that reaches the
// resource only through that second group.
const detachGroupID = "rs-detach-grp"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
const secondSourceGroupID = "rs-source-grp-2"
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
}))
resourcesManager, _, _ := s.managers()
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID, detachGroupID},
Enabled: true,
})
require.NoError(t, err)
// Policy granting the second source group access via the detach group.
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
require.NoError(t, err)
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
settleAffectedUpdates(secondSrcCh)
done := make(chan struct{})
go func() {
// Detaching the resource from detachGroup removes the second source's
// access; that source peer must be refreshed even though the post-update
// resource no longer maps to detachGroup.
peerShouldReceiveUpdate(t, secondSrcCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
}
}
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
// changing router.Peer within the same network must still refresh the OLD routing
// peer, which loses its routing role.
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
_, routersManager, _ := s.managers()
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
require.NoError(t, err)
require.Len(t, routers, 1)
router := routers[0]
oldRoutingPeer := router.Peer
require.NotEmpty(t, oldRoutingPeer)
// A new peer to become the routing peer in place of the old one.
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
settleAffectedUpdates(oldCh)
done := make(chan struct{})
go func() {
// The old routing peer stops serving the resource and must be refreshed.
peerShouldReceiveUpdate(t, oldCh)
close(done)
}()
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
ID: router.ID,
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: newRoutingPeer.ID, // repoint within the same network
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
}
}

View File

@@ -0,0 +1,255 @@
package server
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"sort"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// allPeerMaps computes the serialized per-peer network map for every peer in the
// account, mirroring the controller's compute path so the property test compares
// against real output.
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
t.Helper()
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
require.NoError(t, err)
account.InjectProxyPolicies(ctx)
validated := make(map[string]struct{}, len(account.Peers))
for id := range account.Peers {
validated[id] = struct{}{}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
out := make(map[string]string, len(account.Peers))
for peerID := range account.Peers {
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
// Network.Serial is an account-global counter bumped on every change; it
// is not a per-peer dependency, so normalize it out of the comparison.
if nm.Network != nil {
nm.Network.Serial = 0
}
out[peerID] = canonicalJSON(t, nm)
}
return out
}
// canonicalJSON marshals v and returns an order-insensitive string form: every
// JSON array is sorted by the canonical form of its elements. The network map's
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
// a raw JSON compare would report spurious changes.
func canonicalJSON(t *testing.T, v interface{}) string {
t.Helper()
b, err := json.Marshal(v)
require.NoError(t, err)
var parsed interface{}
require.NoError(t, json.Unmarshal(b, &parsed))
canonicalized, err := json.Marshal(sortAny(parsed))
require.NoError(t, err)
return string(canonicalized)
}
func sortAny(v interface{}) interface{} {
switch val := v.(type) {
case []interface{}:
for i := range val {
val[i] = sortAny(val[i])
}
sort.Slice(val, func(i, j int) bool {
bi, _ := json.Marshal(val[i])
bj, _ := json.Marshal(val[j])
return string(bi) < string(bj)
})
return val
case map[string]interface{}:
for k := range val {
val[k] = sortAny(val[k])
}
return val
default:
return v
}
}
// changedPeers returns the peer IDs whose serialized map differs between before
// and after.
func changedPeers(before, after map[string]string) []string {
var changed []string
for id, b := range before {
a, ok := after[id]
if !ok || a != b {
changed = append(changed, id)
}
}
for id := range after {
if _, ok := before[id]; !ok {
changed = append(changed, id)
}
}
return changed
}
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
// applies random changes, and asserts that the resolver's affected set is a
// superset of the peers whose real network map actually changed. If the resolver
// ever misses a dependency, a change will alter a peer's map without that peer
// appearing in the affected set, failing here.
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A pre-existing peer->resource policy so the resource/router bridge is live.
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
// Extra peers and groups to give mutations room to move membership around.
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
require.NoError(t, err)
extraPeers := make([]string, 0, 4)
for i := 0; i < 4; i++ {
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
extraPeers = append(extraPeers, p.ID)
}
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
for _, g := range extraGroups {
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
}
rng := rand.New(rand.NewSource(1))
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
for iter := 0; iter < 60; iter++ {
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
if apply == nil {
continue
}
before := allPeerMaps(t, s.manager, s.accountID)
resolvedSet := make(map[string]struct{})
resolve := func() {
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
snap, err := affectedpeers.Load(ctx, tx, s.accountID, change)
if err != nil {
return err
}
for _, id := range snap.Expand(ctx, s.accountID, change) {
resolvedSet[id] = struct{}{}
}
return nil
}))
}
// Resolve on both sides of the mutation and union: removals are visible
// only pre-apply (the leaving peer is still a member), additions only
// post-apply (the joining peer is now a member). Production captures both
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
// models that without coupling the test to each path's ordering.
resolve()
changedIDs := change.ChangedPeerIDs
apply()
resolve()
after := allPeerMaps(t, s.manager, s.accountID)
// The explicitly-changed peer's own map refresh is the caller's
// responsibility (the resolver returns the peers to propagate to), so it
// is allowed to be absent from the resolved set.
changedExplicitly := make(map[string]struct{}, len(changedIDs))
for _, id := range changedIDs {
changedExplicitly[id] = struct{}{}
}
for _, id := range changedPeers(before, after) {
if _, stillExists := after[id]; !stillExists {
continue
}
if _, isExplicit := changedExplicitly[id]; isExplicit {
continue
}
_, ok := resolvedSet[id]
require.Truef(t, ok,
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
iter, id, maps.Keys(resolvedSet), change)
}
}
}
// randomMutation picks a random change, returns the Change to resolve and a
// function that applies the underlying store mutation. apply is nil when the
// drawn mutation is a no-op for the current state.
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
t.Helper()
ctx := context.Background()
switch rng.Intn(3) {
case 0:
groupID := allGroups[rng.Intn(len(allGroups))]
peerID := allPeers[rng.Intn(len(allPeers))]
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
require.NoError(t, err)
if slicesContains(grp.Peers, peerID) {
return affectedpeers.Change{}, nil
}
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
func() {
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
}
case 1:
groupID := allGroups[rng.Intn(len(allGroups))]
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
require.NoError(t, err)
if len(grp.Peers) == 0 {
return affectedpeers.Change{}, nil
}
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
func() {
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
}
default:
src := allGroups[rng.Intn(len(allGroups))]
dst := allGroups[rng.Intn(len(allGroups))]
policy := &types.Policy{
Enabled: true,
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
Rules: []*types.PolicyRule{{
Enabled: true,
Sources: []string{src},
Destinations: []string{dst},
Action: types.PolicyTrafficActionAccept,
}},
}
return affectedpeers.Change{Policies: []*types.Policy{policy}},
func() {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
}
}
}
func slicesContains(s []string, v string) bool {
for _, x := range s {
if x == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,164 @@
package server
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// countingStore wraps a real store and counts the per-account collection loads
// the resolver performs, so a test can assert each is read at most once and that
// irrelevant collections are skipped entirely.
type countingStore struct {
store.Store
mu sync.Mutex
counts map[string]int
}
func newCountingStore(s store.Store) *countingStore {
return &countingStore{Store: s, counts: map[string]int{}}
}
func (c *countingStore) bump(name string) {
c.mu.Lock()
c.counts[name]++
c.mu.Unlock()
}
func (c *countingStore) count(name string) int {
c.mu.Lock()
defer c.mu.Unlock()
return c.counts[name]
}
func (c *countingStore) total() int {
c.mu.Lock()
defer c.mu.Unlock()
n := 0
for _, v := range c.counts {
n += v
}
return n
}
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
c.bump("policies")
return c.Store.GetAccountPolicies(ctx, ls, accountID)
}
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
c.bump("routes")
return c.Store.GetAccountRoutes(ctx, ls, accountID)
}
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
c.bump("nameservers")
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
}
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
c.bump("dnssettings")
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
}
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
c.bump("routers")
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
}
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
c.bump("resources")
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
}
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
c.bump("services")
return c.Store.GetAccountServices(ctx, ls, accountID)
}
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
// loads each per-account collection at most once per Resolve (memoization) even
// on a change that drives every bridge, and skips the services table when the
// account has no embedded proxy peers.
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
cs := newCountingStore(s.manager.Store)
// A group change that exercises policies, routers, resources and the bridge.
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
require.NoError(t, err)
affected := snap.Expand(ctx, s.accountID, change)
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
assert.LessOrEqualf(t, cs.count(name), 1,
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
}
assert.Equal(t, 0, cs.count("services"),
"services must not be loaded when the account has no embedded proxy peers")
}
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
// no group/peer signal touches no per-account collections beyond what its inputs
// require.
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
cs := newCountingStore(s.manager.Store)
// A bare network change drives only the router->source bridge: routers and
// resources are needed, but routes/nameservers/dnssettings/services are not.
_, err := affectedpeers.Load(ctx, cs, s.accountID, affectedpeers.Change{Networks: []*networkTypes.Network{{ID: s.networkID}}})
require.NoError(t, err)
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
}
// TestAffectedPeers_QueryCount_ExpandReadsNothing is the core invariant of the
// Load/Expand split: Load (run inside the transaction) does all store reads;
// Expand (run after commit) must touch the store ZERO times, so it never holds
// the write lock and never reads post-commit state.
func TestAffectedPeers_QueryCount_ExpandReadsNothing(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
change := affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}}
cs := newCountingStore(s.manager.Store)
snap, err := affectedpeers.Load(ctx, cs, s.accountID, change)
require.NoError(t, err)
require.Greater(t, cs.total(), 0, "Load must read the store")
// Any store access during Expand would increment the same counter. Expand
// operates purely on the snapshot, so the count must not move.
readsAfterLoad := cs.total()
affected := snap.Expand(ctx, s.accountID, change)
assert.Contains(t, affected, s.routerPeerID, "Expand must still produce the affected peers from the snapshot")
assert.Equal(t, readsAfterLoad, cs.total(), "Expand must perform zero store reads — it operates purely on the loaded snapshot")
}

View File

@@ -0,0 +1,333 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
change := affectedpeers.Change{ChangedGroupIDs: changedGroupIDs}
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
if err != nil {
return nil
}
return snap.Expand(ctx, s.accountID, change)
}
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
change := affectedpeers.Change{ChangedPeerIDs: changedPeerIDs}
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
if err != nil {
return nil
}
return snap.Expand(ctx, s.accountID, change)
}
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
assert.Contains(t, affected, s.routerPeerID,
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.routerGroupPeerID,
"changing the source group must refresh the routing peer defined via router.PeerGroups")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
assert.Contains(t, affected, s.sourcePeerID,
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID,
"a status change on a source peer must refresh the resource's routing peer that serves it")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID,
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const memberOnlyGroupID = "rs-memberonly-grp"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
}))
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
}
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const extraResourceGroupID = "rs-resource-grp-extra"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: extraResourceGroupID, Name: "rs-resource-extra",
}))
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
require.NoError(t, err)
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
ID: s.resourceID,
Type: types.ResourceTypeHost,
}))
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
assert.Contains(t, affected, s.routerPeerID,
"attaching a resource to a policy destination group must refresh the resource's routing peer")
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
t.Helper()
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
Name: "rs-min-version",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
},
}, true)
require.NoError(t, err)
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.SourcePostureChecks = []string{check.ID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
return check.ID
}
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
checkID := s.createPostureCheckGatedPolicy(t, ctx)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
ID: checkID,
Name: "rs-min-version",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
},
}, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: editing a posture check did not refresh source + routing peers")
}
}
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
resourcesManager, _, _ := s.managers()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/25",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
}
}
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
resourcesManager, routersManager, _ := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: disabledRouterPeer.ID,
Masquerade: true,
Metric: 9000,
Enabled: false,
})
require.NoError(t, err)
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
settleAffectedUpdates(disabledCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, disabledCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/25",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
}
}
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "groupiso")
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a source-group change for another resource")
}
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "peeriso")
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a source-peer change for another resource")
}

View File

@@ -0,0 +1,771 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// routerScenario captures the topology from the bug report:
//
// network ── router (routing peer) ── resource (in resourceGroup)
// independent peer ──(policy: source -> resource)──> resource
//
// The routing peer must be refreshed when a policy grants a source peer access
// to the resource, because the network map connects the source peer to the
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
// group, so static group/peer resolution alone cannot find it.
type routerScenario struct {
manager *DefaultAccountManager
updateManager *update_channel.PeersUpdateManager
accountID string
networkID string
sourcePeerID string // independent peer that the policy grants access from
sourceGroupID string // group containing the source peer
routerPeerID string // peer acting as the routing peer (direct router.Peer)
routerGroupPeerID string // peer that is a member of routerPeerGroup
routerPeerGroupID string // group used for router.PeerGroups
resourceID string // network resource
resourceGroupID string // group whose member is the resource (no peers)
unrelatedPeerID string // peer in no relevant entity
}
// setupRouterScenario builds the topology above with the default policy removed
// and channels NOT yet created, so callers control exactly when updates can flow.
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
t.Helper()
manager, updateManager, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
account, err := createAccount(manager, "router_scenario", userID, "")
require.NoError(t, err)
accountID := account.Id
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, p := range policies {
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
}
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
const (
sourceGroupID = "rs-source-grp"
routerPeerGroupID = "rs-router-grp"
resourceGroupID = "rs-resource-grp"
)
for _, g := range []*types.Group{
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
} {
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
}
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
ID: "rs-network",
AccountID: accountID,
Name: "rs-network",
})
require.NoError(t, err)
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
AccountID: accountID,
NetworkID: network.ID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
router := &routerTypes.NetworkRouter{
ID: "rs-router",
NetworkID: network.ID,
AccountID: accountID,
Masquerade: true,
Metric: 9999,
Enabled: true,
}
if directRouterPeer {
router.Peer = routerPeer.ID
} else {
router.PeerGroups = []string{routerPeerGroupID}
}
_, err = routersManager.CreateRouter(ctx, userID, router)
require.NoError(t, err)
return &routerScenario{
manager: manager,
updateManager: updateManager,
accountID: accountID,
networkID: network.ID,
sourcePeerID: sourcePeer.ID,
sourceGroupID: sourceGroupID,
routerPeerID: routerPeer.ID,
routerGroupPeerID: routerGroupPeer.ID,
routerPeerGroupID: routerPeerGroupID,
resourceID: resource.ID,
resourceGroupID: resourceGroupID,
unrelatedPeerID: unrelatedPeer.ID,
}
}
// peerToResourcePolicy builds a policy granting the source group access to the
// resource, referencing the resource by its group in the rule destination.
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
return &types.Policy{
Enabled: true,
Name: "peer-to-resource-by-group",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{sourceGroupID},
Destinations: []string{resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
}
// peerToResourcePolicyByResource builds a policy referencing the resource
// directly via DestinationResource rather than its group.
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
return &types.Policy{
Enabled: true,
Name: "peer-to-resource-by-resource",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{sourceGroupID},
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
Action: types.PolicyTrafficActionAccept,
},
},
}
}
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
// peers for the given policy.
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
change := affectedpeers.Change{Policies: []*types.Policy{policy}}
snap, err := affectedpeers.Load(ctx, s.manager.Store, s.accountID, change)
if err != nil {
return nil
}
return snap.Expand(ctx, s.accountID, change)
}
func TestAffectedPeers_SourcePeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
}
func TestAffectedPeers_RoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// BUG: the direct routing peer serves the resource's subnet to the source
// peer, so it must be refreshed when the policy is created. The policy path
// only resolves the literal rule groups (source group + resource group);
// the resource group has no peer members and the router peer is reachable
// only through the network, so it is dropped.
assert.Contains(t, affected, s.routerPeerID,
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
}
func TestAffectedPeers_RoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// Router defined via PeerGroups instead of a direct peer.
assert.Contains(t, affected, s.routerGroupPeerID,
"routing peer (router.PeerGroups member) serving the resource must be affected")
}
func TestAffectedPeers_DestResource_RoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
affected := s.resolvePolicyAffected(ctx, policy)
// When the resource is referenced via DestinationResource, RuleGroups()
// returns only the source group and the resource ID is not a peer, so
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
// all. The routing peer is dropped here too.
assert.Contains(t, affected, s.routerPeerID,
"routing peer must be affected when the resource is referenced via DestinationResource")
}
func TestAffectedPeers_DestResource_RoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerGroupPeerID,
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
}
func TestAffectedPeers_SourceResourcePeer_RoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// Source expressed as a direct peer (SourceResource), destination as resource group.
policy := &types.Policy{
Enabled: true,
Name: "sourceResource-peer-to-resource",
Rules: []*types.PolicyRule{
{
Enabled: true,
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
Destinations: []string{s.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
affected := s.resolvePolicyAffected(ctx, policy)
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
// handles SourceResource peers), but the routing peer is still missing.
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
}
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// Guard against an over-broad fix: a peer in no relevant entity must never
// be pulled in.
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A pre-existing policy grants the source group access to the resource.
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
// Drive an update through the resource manager and assert the routing peer
// is among the affected set by observing the channel. This path walks
// policies whose destinations reference the resource's groups, folds in the
// source groups, and loads the network's routers, so it reaches both the
// source peer and the routing peer.
permissionsManager := permissions.NewManager(s.manager.Store)
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh source peer + routing peer")
}
}
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
// every given channel so subsequent assertions start from a clean slate.
//
// Setup (CreateNetwork/CreateResource/CreateRouter) fires async UpdateAffectedPeers
// goroutines; draining first means the assertion only observes updates from the
// action under test, not setup stragglers.
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
time.Sleep(300 * time.Millisecond)
for _, ch := range chans {
drainPeerUpdates(ch)
}
}
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
}
}
func TestAffectedPeers_E2E_CreatePolicy_RoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
}
}
func TestAffectedPeers_E2E_DestResource_RoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
}
}
func TestAffectedPeers_E2E_DeletePolicy_RoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
}
}
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
permissionsManager := permissions.NewManager(s.manager.Store)
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
return resourcesManager, routersManager, networksManager
}
type secondTopology struct {
networkID string
resourceID string
resourceGroupID string
routerPeerID string
}
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
t.Helper()
ctx := context.Background()
resourcesManager, routersManager, networksManager := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
resourceGroupID := "rs-resource-grp-" + suffix
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: resourceGroupID, Name: "rs-resource-" + suffix,
}))
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
ID: "rs-network-" + suffix,
AccountID: s.accountID,
Name: "rs-network-" + suffix,
})
require.NoError(t, err)
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
AccountID: s.accountID,
NetworkID: network.ID,
Name: "rs-resource-host-" + suffix,
Address: "10.40.50.0/24",
GroupIDs: []string{resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: network.ID,
AccountID: s.accountID,
Peer: routerPeer.ID,
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
return secondTopology{
networkID: network.ID,
resourceID: resource.ID,
resourceGroupID: resourceGroupID,
routerPeerID: routerPeer.ID,
}
}
func TestAffectedPeers_E2E_UpdatePolicy_BothRoutingPeers(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "b")
ctx := context.Background()
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, second.routerPeerID)
})
settleAffectedUpdates(srcCh, routerACh, routerBCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerACh)
peerShouldReceiveUpdate(t, routerBCh)
close(done)
}()
policy.Rules[0].Destinations = []string{second.resourceGroupID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
}
}
func TestAffectedPeers_E2E_UpdatePolicy_AddSource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const secondSourceGroupID = "rs-source-grp-2"
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
}))
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(newSrcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, newSrcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
}
}
func TestAffectedPeers_E2E_DestResource_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
}
}
func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, routersManager, _ := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: secondRouterPeer.ID,
Masquerade: true,
Metric: 9998,
Enabled: true,
})
require.NoError(t, err)
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
}
func TestAffectedPeers_DisabledRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
require.NoError(t, err)
require.Len(t, routers, 1)
routers[0].Enabled = false
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
}
func TestAffectedPeers_DisabledResource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
require.NoError(t, err)
res.Enabled = false
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_DisabledRule(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.Rules[0].Enabled = false
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerPeerID,
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_MultiRule(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "c")
ctx := context.Background()
policy := &types.Policy{
Enabled: true,
Name: "multi-rule-two-resources",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{s.sourceGroupID},
Destinations: []string{s.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{s.sourceGroupID},
Destinations: []string{second.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
}
func TestAffectedPeers_RouterOtherNetwork(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "d")
ctx := context.Background()
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a policy that does not target its resource")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,825 @@
// Package affectedpeers computes which peers' network maps a change touches, so
// only those peers are refreshed instead of the whole account.
//
// Two phases keep the dependency walk off the write transaction:
// - Load: reads the needed collections. Call INSIDE the mutating tx (consistent,
// and before a delete/removal severs the old state).
// - Snapshot.Expand: in-memory walk, no store access. Run AFTER the tx commits.
//
// Enabled is never consulted: toggling it is itself an observable change.
package affectedpeers
import (
"context"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// Snapshot is an in-memory view of the collections needed to expand a Change.
// Loaded in-tx, walked by Expand after commit. Only the collections the Change
// can touch are loaded; the rest stay nil (see Load).
type Snapshot struct {
policies []*types.Policy
routes []*route.Route
nsGroups []*nbdns.NameServerGroup
dnsSettings *types.DNSSettings
routers []*routerTypes.NetworkRouter
resources []*resourceTypes.NetworkResource
services []*rpservice.Service
proxyByCluster map[string][]string
groups map[string]*types.Group
groupPeers map[string]map[string]struct{} // groupID -> member peer IDs
}
// Load reads the collections a Change requires, inside the caller's tx. It mirrors
// Expand's walker preconditions, loading only what the change can touch.
func Load(ctx context.Context, s store.Store, accountID string, c Change) (*Snapshot, error) {
snap := &Snapshot{}
if c.isEmpty() {
return snap, nil
}
if err := snap.loadCollections(ctx, s, accountID, c); err != nil {
return nil, err
}
if err := snap.loadGroupIndex(ctx, s, accountID); err != nil {
return nil, err
}
return snap, nil
}
// loadCollections reads the policy/route/nameserver/dns/router/resource/proxy
// collections a Change can touch, gated to what the walk needs.
func (snap *Snapshot) loadCollections(ctx context.Context, s store.Store, accountID string, c Change) error {
hasGroupOrPeerChange := len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 || len(c.Resources) > 0
hasNetworkObject := len(c.Routers) > 0 || len(c.Resources) > 0 || len(c.Networks) > 0
// the resource<->router bridge can fire for any of these
needsRoutersResources := hasGroupOrPeerChange || len(c.PostureCheckIDs) > 0 || len(c.Policies) > 0 || hasNetworkObject
if needsRoutersResources {
if err := snap.loadPolicyRoutersResources(ctx, s, accountID); err != nil {
return err
}
}
if hasGroupOrPeerChange {
if err := snap.loadRoutesAndProxy(ctx, s, accountID); err != nil {
return err
}
}
if len(c.ChangedGroupIDs) > 0 || len(c.ChangedPeerIDs) > 0 {
if err := snap.loadDNS(ctx, s, accountID); err != nil {
return err
}
}
return nil
}
// loadPolicyRoutersResources loads the policies plus the routers and resources
// the resource<->router bridge walks.
func (snap *Snapshot) loadPolicyRoutersResources(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.policies, err = s.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
if snap.routers, err = s.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
snap.resources, err = s.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
return err
}
// loadRoutesAndProxy loads the routes and the embedded-proxy services index.
func (snap *Snapshot) loadRoutesAndProxy(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.routes, err = s.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
return snap.loadProxyServices(ctx, s, accountID)
}
// loadDNS loads the nameserver groups and account DNS settings.
func (snap *Snapshot) loadDNS(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.nsGroups, err = s.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID); err != nil {
return err
}
snap.dnsSettings, err = s.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
return err
}
// loadProxyServices loads the embedded-proxy cluster index, and the services only
// when the account actually has embedded proxy peers.
func (snap *Snapshot) loadProxyServices(ctx context.Context, s store.Store, accountID string) error {
var err error
if snap.proxyByCluster, err = s.GetEmbeddedProxyPeerIDsByCluster(ctx, accountID); err != nil {
return err
}
if len(snap.proxyByCluster) == 0 {
return nil
}
snap.services, err = s.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
return err
}
// loadGroupIndex loads all groups (for group.Resources) and builds the
// group->member-peers index. Always needed: the bridge resolves group.Resources
// and Expand maps groups to member peers.
func (snap *Snapshot) loadGroupIndex(ctx context.Context, s store.Store, accountID string) error {
groups, err := s.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
snap.groups = make(map[string]*types.Group, len(groups))
snap.groupPeers = make(map[string]map[string]struct{}, len(groups))
for _, g := range groups {
snap.groups[g.ID] = g
members := make(map[string]struct{}, len(g.Peers))
for _, pID := range g.Peers {
members[pID] = struct{}{}
}
snap.groupPeers[g.ID] = members
}
return nil
}
// Change describes what changed in an account.
type Change struct {
ChangedGroupIDs []string
ChangedPeerIDs []string
Policies []*types.Policy
Routes []*route.Route
Routers []*routerTypes.NetworkRouter
Resources []*resourceTypes.NetworkResource
Networks []*networkTypes.Network
PostureCheckIDs []string
// DistributionGroupIDs are groups whose members are directly affected, with no
// dependency walk — the change distributes config to the groups' member peers
// only (nameserver groups, DNS DisabledManagementGroups), not through the
// policy/route reachability graph. Pass oldnew so both states refresh.
DistributionGroupIDs []string
// RemovedPeersByGroup: peers that left a group, keyed by that group. They are no
// longer in the group's member index but still lose its reachability, so they are
// folded in — but only when the group is linked (an unlinked group has no map
// impact), matching how current members are handled.
RemovedPeersByGroup map[string][]string
}
func (c Change) isEmpty() bool {
return len(c.ChangedGroupIDs) == 0 &&
len(c.ChangedPeerIDs) == 0 &&
len(c.Policies) == 0 &&
len(c.Routes) == 0 &&
len(c.Routers) == 0 &&
len(c.Resources) == 0 &&
len(c.Networks) == 0 &&
len(c.PostureCheckIDs) == 0 &&
len(c.DistributionGroupIDs) == 0 &&
len(c.RemovedPeersByGroup) == 0
}
// Expand returns the deduplicated affected peer IDs from the preloaded Snapshot,
// no store access. Run after the producing tx commits. Logs the full walk at
// trace level for diagnosing a miscalculation.
func (snap *Snapshot) Expand(ctx context.Context, accountID string, c Change) []string {
if c.isEmpty() {
return nil
}
r := newResolver(ctx, snap, accountID, c)
log.WithContext(ctx).Tracef("affectedpeers expand start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d routers=%d resources=%d networks=%d postureChecks=%v distributionGroups=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), len(c.Routers), len(c.Resources), len(c.Networks), c.PostureCheckIDs, c.DistributionGroupIDs)
r.walk()
return r.expand()
}
// Collect returns the affected group and direct-peer IDs without expanding groups
// to members. Test-only introspection; use Resolve otherwise.
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
if c.isEmpty() {
return nil, nil
}
snap, err := Load(ctx, s, accountID, c)
if err != nil {
log.WithContext(ctx).Errorf("failed to load snapshot for affected peers collect: %v", err)
return nil, nil
}
r := newResolver(ctx, snap, accountID, c)
r.walk()
return setToSlice(r.groupSet), setToSlice(r.peerSet)
}
func newResolver(ctx context.Context, snap *Snapshot, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
snap: snap,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
networkIDs: make(map[string]struct{}),
}
// Resolve each changed peer to its groups here so callers pass only ChangedPeerIDs.
r.seedChangedGroupsFromPeers()
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
return r
}
// seedChangedGroupsFromPeers adds each changed peer's groups to changedGroupSet so
// the group-driven walkers fire for memberships, not just direct peer references.
func (r *resolver) seedChangedGroupsFromPeers() {
if len(r.changedPeerSet) == 0 {
return
}
for groupID, members := range r.snap.groupPeers {
for pID := range r.changedPeerSet {
if _, ok := members[pID]; ok {
r.changedGroupSet[groupID] = struct{}{}
break
}
}
}
}
func (r *resolver) walk() {
r.collectFromExplicitPolicies()
r.collectFromExplicitRoutes(r.change.Routes)
r.collectFromExplicitRouters(r.change.Routers)
r.collectFromExplicitResources(r.change.Resources)
r.collectFromExplicitNetworks(r.change.Networks)
r.collectFromPostureChecks(r.change.PostureCheckIDs)
// Distribution groups (nameserver/DNS) affect only their member peers: fold them
// straight into groupSet so expand() maps them to members, without the policy/
// route walk that changedGroupSet would trigger.
addAll(r.groupSet, r.change.DistributionGroupIDs)
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
r.collectFromPolicies()
r.collectFromRoutes()
r.collectFromNameServers()
r.collectFromDNSSettings()
r.collectFromNetworkRouters()
r.collectFromProxyServices()
}
r.collectResourceRouterBridge()
}
type resolver struct {
ctx context.Context
snap *Snapshot
accountID string
change Change
changedGroupSet map[string]struct{}
changedPeerSet map[string]struct{}
groupSet map[string]struct{}
peerSet map[string]struct{}
matchedPolicies []*types.Policy
networkIDs map[string]struct{}
}
func (r *resolver) policies() []*types.Policy { return r.snap.policies }
func (r *resolver) networkResources() []*resourceTypes.NetworkResource { return r.snap.resources }
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter { return r.snap.routers }
// peerIDsForGroups maps a group set to its member peer IDs via the preloaded index.
func (r *resolver) peerIDsForGroups(groupSet map[string]struct{}) []string {
seen := make(map[string]struct{})
var ids []string
for gID := range groupSet {
for pID := range r.snap.groupPeers[gID] {
if _, ok := seen[pID]; ok {
continue
}
seen[pID] = struct{}{}
ids = append(ids, pID)
}
}
return ids
}
func (r *resolver) expand() []string {
peerIDs := r.peerIDsForGroups(r.groupSet)
log.WithContext(r.ctx).Tracef("affectedpeers expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
r.accountID, setToSlice(r.groupSet), len(peerIDs), setToSlice(r.peerSet))
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for id := range r.peerSet {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
// Fold in removed peers only when their group is linked (in groupSet).
for groupID, removed := range r.change.RemovedPeersByGroup {
if _, linked := r.groupSet[groupID]; !linked {
continue
}
for _, id := range removed {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
log.WithContext(r.ctx).Tracef("affectedpeers expand: removed peer %s from linked group %s -> affected", id, groupID)
}
}
}
log.WithContext(r.ctx).Tracef("affectedpeers expand done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
return peerIDs
}
func (r *resolver) collectFromExplicitPolicies() {
for _, policy := range r.matchedPolicies {
if policy == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
}
}
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
// collectFromExplicitRouters folds changed routers' peers and marks their networks
// for the bridge. Passing the old router keeps a repointed router's previous peers
// affected without a post-commit read.
func (r *resolver) collectFromExplicitRouters(routers []*routerTypes.NetworkRouter) {
for _, router := range routers {
if router == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRouters: changed router %s on network %s -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
if router.NetworkID != "" {
r.networkIDs[router.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitResources marks changed resources' networks for the bridge and
// treats their group IDs as changed, so policies targeting the resource via a
// now-detached (old) group still refresh.
func (r *resolver) collectFromExplicitResources(resources []*resourceTypes.NetworkResource) {
for _, resource := range resources {
if resource == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitResources: changed resource %s on network %s -> marking network for bridge and treating groups %v as changed",
resource.ID, resource.NetworkID, resource.GroupIDs)
addAll(r.changedGroupSet, resource.GroupIDs)
if resource.NetworkID != "" {
r.networkIDs[resource.NetworkID] = struct{}{}
}
}
}
// collectFromExplicitNetworks marks changed networks for the bridge. A network has
// no groups/peers of its own.
func (r *resolver) collectFromExplicitNetworks(networks []*networkTypes.Network) {
for _, network := range networks {
if network == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitNetworks: changed network %s -> marking for bridge", network.ID)
if network.ID != "" {
r.networkIDs[network.ID] = struct{}{}
}
}
}
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
if len(postureCheckIDs) == 0 {
return
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if !policyReferencesPostureChecks(policy, ids) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromRoutes() {
for _, rt := range r.snap.routes {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromNameServers() {
if len(r.changedGroupSet) == 0 {
return
}
for _, ns := range r.snap.nsGroups {
if anyInSet(ns.Groups, r.changedGroupSet) {
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
addAll(r.groupSet, ns.Groups)
}
}
}
func (r *resolver) collectFromDNSSettings() {
if len(r.changedGroupSet) == 0 || r.snap.dnsSettings == nil {
return
}
for _, gID := range r.snap.dnsSettings.DisabledManagementGroups {
if _, ok := r.changedGroupSet[gID]; ok {
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
r.groupSet[gID] = struct{}{}
}
}
}
func (r *resolver) collectFromNetworkRouters() {
for _, router := range r.networkRouters() {
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
r.networkIDs[router.NetworkID] = struct{}{}
}
}
func (r *resolver) collectFromProxyServices() {
if len(r.snap.proxyByCluster) == 0 || len(r.snap.services) == 0 {
return
}
services, proxyByCluster := r.snap.services, r.snap.proxyByCluster
expanded := r.expandChangedPeersWithGroups()
for _, svc := range services {
if svc == nil {
continue
}
proxyPeers := proxyByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
if !matchedByPeer && !matchedByAccessGroup {
continue
}
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
for _, pid := range proxyPeers {
r.peerSet[pid] = struct{}{}
}
for _, target := range svc.Targets {
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
r.peerSet[target.TargetId] = struct{}{}
}
}
addAll(r.groupSet, svc.AccessGroups)
}
}
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
if len(r.changedGroupSet) == 0 {
return r.changedPeerSet
}
ids := r.peerIDsForGroups(r.changedGroupSet)
if len(ids) == 0 {
return r.changedPeerSet
}
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
for id := range r.changedPeerSet {
merged[id] = struct{}{}
}
for _, id := range ids {
merged[id] = struct{}{}
}
return merged
}
// collectResourceRouterBridge crosses between source peers and routing peers, which
// are reachable only via resource -> network -> router, not through the policy's own
// groups: source -> router (targeted resources' networks), then router -> source.
func (r *resolver) collectResourceRouterBridge() {
r.bridgeSourceToRouters()
r.bridgeRoutersToSources()
}
func (r *resolver) bridgeSourceToRouters() {
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
if len(resourceIDs) == 0 {
return
}
networkIDs := r.resourceNetworkIDs(resourceIDs)
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
r.networkIDs[id] = struct{}{}
}
}
func (r *resolver) bridgeRoutersToSources() {
if len(r.networkIDs) == 0 {
return
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
setToSlice(r.networkIDs))
r.foldRoutersOnNetworks(r.networkIDs)
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
resourceIDs[resource.ID] = struct{}{}
}
}
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.groupSet, r.peerSet)
}
}
}
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
for _, router := range r.networkRouters() {
if _, ok := networkIDs[router.NetworkID]; !ok {
continue
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
}
}
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
networkIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := resourceIDs[resource.ID]; ok {
networkIDs[resource.NetworkID] = struct{}{}
}
}
return networkIDs
}
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
if policy == nil {
return false
}
destGroupSet := make(map[string]struct{})
for _, rule := range policy.Rules {
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
return true
}
for _, gID := range rule.Destinations {
destGroupSet[gID] = struct{}{}
}
}
if len(destGroupSet) == 0 {
return false
}
for gID := range destGroupSet {
group := r.snap.groups[gID]
if group == nil {
continue
}
for _, res := range group.Resources {
if isInSet(res.ID, resourceIDs) {
return true
}
}
}
return false
}
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
resourceIDs := make(map[string]struct{})
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
r.addGroupResourceIDs(destGroupSet, resourceIDs)
return resourceIDs
}
// collectPolicyDestinations adds direct destination resource IDs to resourceIDs and
// returns the referenced destination group IDs.
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
destGroupSet := make(map[string]struct{})
for _, policy := range policies {
if policy == nil {
continue
}
for _, rule := range policy.Rules {
addAll(destGroupSet, rule.Destinations)
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
resourceIDs[rule.DestinationResource.ID] = struct{}{}
}
}
}
return destGroupSet
}
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
for gID := range groupIDs {
group := r.snap.groups[gID]
if group == nil {
continue
}
for _, res := range group.Resources {
if res.ID != "" {
resourceIDs[res.ID] = struct{}{}
}
}
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
addAll(groupSet, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {
return true
}
}
return false
}
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
if res.Type != types.ResourceTypePeer || res.ID == "" {
return false
}
_, ok := set[res.ID]
return ok
}
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
for _, pid := range proxyPeers {
if _, ok := changedPeers[pid]; ok {
return true
}
}
for _, target := range svc.Targets {
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
continue
}
if _, ok := changedPeers[target.TargetId]; ok {
return true
}
}
return false
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func isInSet(id string, set map[string]struct{}) bool {
_, ok := set[id]
return ok
}
func addAll(set map[string]struct{}, slices ...[]string) {
for _, s := range slices {
for _, id := range s {
set[id] = struct{}{}
}
}
}
func toSet(ids []string) map[string]struct{} {
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
func setToSlice(set map[string]struct{}) []string {
s := make([]string, 0, len(set))
for id := range set {
s = append(s, id)
}
return s
}

View File

@@ -0,0 +1,140 @@
package affectedpeers
import (
"testing"
"github.com/stretchr/testify/assert"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/types"
)
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
// direct peers) the resolver folds in, for asserting the pure logic.
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
peerSet := map[string]struct{}{}
for _, p := range policies {
if p == nil {
continue
}
groups = append(groups, p.RuleGroups()...)
collectPolicyDirectPeers(p, peerSet)
}
for id := range peerSet {
peers = append(peers, id)
}
return groups, peers
}
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
assert.Empty(t, peers)
}
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
Destinations: []string{"g2"},
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
}
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
groups, peers := policyGroupsAndPeers(nil)
assert.Nil(t, groups)
assert.Nil(t, peers)
}
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
groups, _ := policyGroupsAndPeers(updated, old)
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
}
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
Destinations: []string{"g2"},
}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
}
func TestChangeIsEmpty(t *testing.T) {
assert.True(t, Change{}.isEmpty())
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
assert.False(t, Change{Resources: []*resourceTypes.NetworkResource{{ID: "r"}}}.isEmpty())
assert.False(t, Change{Networks: []*networkTypes.Network{{ID: "n"}}}.isEmpty())
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
}
func TestCollectPolicyDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
}, {
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
peerSet := map[string]struct{}{}
collectPolicyDirectPeers(policy, peerSet)
assert.Contains(t, peerSet, "p1")
assert.Contains(t, peerSet, "p2")
assert.NotContains(t, peerSet, "r1")
}
func TestCollectPolicySources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
Destinations: []string{"g2"},
}}}
groupSet := map[string]struct{}{}
peerSet := map[string]struct{}{}
collectPolicySources(policy, groupSet, peerSet)
assert.Contains(t, groupSet, "g1")
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
assert.Contains(t, peerSet, "p1")
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/base62"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
@@ -74,7 +75,10 @@ func (m *manager) ValidateAndParseToken(ctx context.Context, value string) (auth
}
func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth auth.UserAuth, token *jwt.Token) (auth.UserAuth, error) {
if userAuth.IsChild || userAuth.IsPAT {
// Child accounts and PAT-authenticated requests do not use JWT group access checks.
// Embedded-Dex local users also skip them because local password authentication
// does not provide external IdP group claims.
if userAuth.IsChild || userAuth.IsPAT || dex.IsLocalUserID(userAuth.UserId) {
return userAuth, nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
@@ -206,6 +207,43 @@ func TestAuthManager_EnsureUserAccessByJWTGroups(t *testing.T) {
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), userAuth, token)
require.Error(t, err, "ensure user access is not in allowed groups")
})
t.Run("Local embedded-Dex user is exempt from JWT allow-groups", func(t *testing.T) {
account.Settings.JWTGroupsEnabled = true
account.Settings.JWTGroupsClaimName = "idp-groups"
account.Settings.JWTAllowGroups = []string{"not-a-group"}
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err, "save account failed")
// Local Dex users have a "local" connector encoded in their user ID.
localUserAuth := nbauth.UserAuth{
AccountId: account.Id,
Domain: domain,
UserId: dex.EncodeDexUserID("local-owner", "local"),
}
localUserAuth, err = manager.EnsureUserAccessByJWTGroups(context.Background(), localUserAuth, token)
require.NoError(t, err, "local user must not be locked out by JWT allow-groups (issue #5337)")
require.Len(t, localUserAuth.Groups, 0, "JWT groups must not be evaluated for local users")
})
t.Run("Federated embedded-Dex user is still subject to JWT allow-groups", func(t *testing.T) {
account.Settings.JWTGroupsEnabled = true
account.Settings.JWTGroupsClaimName = "idp-groups"
account.Settings.JWTAllowGroups = []string{"not-a-group"}
err := store.SaveAccount(context.Background(), account)
require.NoError(t, err, "save account failed")
// A federated user (non-"local" connector) must remain restricted.
fedUserAuth := nbauth.UserAuth{
AccountId: account.Id,
Domain: domain,
UserId: dex.EncodeDexUserID("entra-user", "entra"),
}
_, err = manager.EnsureUserAccessByJWTGroups(context.Background(), fedUserAuth, token)
require.Error(t, err, "federated user must still be restricted by JWT allow-groups")
})
}
func TestAuthManager_ValidateAndParseToken(t *testing.T) {

View File

@@ -8,6 +8,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
@@ -47,8 +48,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
var snap *affectedpeers.Snapshot
var change affectedpeers.Change
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
@@ -63,11 +65,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
@@ -75,6 +72,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
change = affectedpeers.Change{DistributionGroupIDs: slices.Concat(addedGroups, removedGroups)}
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err
}
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -85,9 +87,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
}
am.ExpandAndUpdateAffected(ctx, accountID, snap, change)
return nil
}
@@ -133,20 +133,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {

View File

@@ -298,11 +298,11 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account
return nil, err
}
savedPeer1, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
savedPeer1, _, _, _, err := am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer1, false)
if err != nil {
return nil, err
}
_, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
_, _, _, _, err = am.AddPeer(context.Background(), "", "", dnsAdminUserID, peer2, false)
if err != nil {
return nil, err
}

Some files were not shown because too many files have changed in this diff Show More