Compare commits

21 Commits

Author SHA1 Message Date
jnfrati
520370a8b0 Merge branch 'main' of github.com:netbirdio/netbird into feat/admin-cli 2026-06-22 17:17:47 +02:00
Bethuel Mmbaga
af3b7e4497 [misc] Add enterprise getting-started and migrate script (#6501) 2026-06-22 16:58:45 +03:00
Zoltan Papp
e84f6527f7 [client] fix WaitStreamConnected test call after ctx signature change (#6503)
watchdog_test.go called WaitStreamConnected() without the context.Context
argument added in #6443, breaking the signal client test build.
2026-06-22 15:53:11 +02:00
Zoltan Papp
ac9529ea8c [client] Fix engine lifecycle race (#6443)
* [client] always clean up on Engine.Start failure via defer

The rosenpass init paths (NewManager/Run) returned without calling
e.close(), leaking the WireGuard interface and other partially
initialized state on failure. Per-branch cleanup was easy to miss when
adding new early returns.

Convert Start to a named error return and tear down via a single defer
that calls e.close() whenever err != nil, removing the scattered
per-branch close() calls (including the redundant one in initFirewall).
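
The named-return-plus-defer pattern described above can be sketched as follows. This is a minimal illustration, not the actual Engine code: the `engine` type, its `close` method, `errInit`, and the `failAt` parameter are all hypothetical stand-ins.

```go
package main

import (
	"errors"
	"fmt"
)

// engine is a hypothetical stand-in for the real Engine type.
type engine struct {
	closed bool
}

// close releases partially initialized state (illustrative only).
func (e *engine) close() { e.closed = true }

// errInit is a made-up error used to simulate a failing init step.
var errInit = errors.New("init failed")

// start uses a named error return so a single defer covers every early return.
func (e *engine) start(failAt int) (err error) {
	defer func() {
		if err != nil {
			e.close()
		}
	}()

	for step := 1; step <= 3; step++ {
		if step == failAt {
			return fmt.Errorf("step %d: %w", step, errInit)
		}
	}
	return nil
}

func main() {
	e := &engine{}
	err := e.start(2)
	fmt.Println(err, e.closed) // step 2: init failed true

	ok := &engine{}
	err = ok.start(0)
	fmt.Println(err, ok.closed) // <nil> false
}
```

Adding a new early return to `start` needs no extra cleanup call, which is the property the commit message is after.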

* [client] make Engine single-use and guard against double Start

Create the run context once in NewEngine instead of in Start. This
keeps e.cancel valid for the engine's whole lifetime, so Stop can
cancel a Start that is blocked waiting on the network while holding
syncMsgMux: Stop now cancels before taking the lock, unblocking that
Start so it can release the mutex.

Reject re-entry into Start: a non-nil wgInterface means a prior Start
already ran (ErrEngineAlreadyStarted), and a cancelled run context
means the engine was stopped (ErrEngineAlreadyStopped). Both checks run
before the cleanup defer so a duplicate call cannot tear down the
running engine's state.

* [client] let engine context unblock WaitStreamConnected

WaitStreamConnected only watched the signal client's own context, which
derives from the parent engineCtx rather than the engine's run context.
A Start blocked here (signal stream not yet up) could therefore not be
released by Engine.Stop, since Stop only cancels the engine's run
context.

Pass a context into WaitStreamConnected and select on it too, and have
the engine pass e.ctx, so Stop cancelling e.ctx unblocks a parked Start.
Update the Client interface, the mock, and callers accordingly.
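
A rough sketch of the idea, assuming the wait boils down to a select over a "connected" channel and two contexts. The function name, its parameters, and the `stopReleasesWaiter` helper are hypothetical and do not reproduce the signal client's real interface.

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

// waitStreamConnected blocks until the stream is reported connected or
// either context is cancelled. All names here are hypothetical.
func waitStreamConnected(runCtx, clientCtx context.Context, connected <-chan struct{}) error {
	select {
	case <-connected:
		return nil
	case <-runCtx.Done():
		return runCtx.Err()
	case <-clientCtx.Done():
		return clientCtx.Err()
	}
}

// stopReleasesWaiter simulates Stop cancelling the run context while the
// stream never connects, and reports whether the waiter saw the cancellation.
func stopReleasesWaiter() bool {
	runCtx, stop := context.WithCancel(context.Background())
	connected := make(chan struct{}) // never closed: the stream stays down

	result := make(chan error, 1)
	go func() {
		result <- waitStreamConnected(runCtx, context.Background(), connected)
	}()

	stop() // stands in for Engine.Stop cancelling the run context
	return errors.Is(<-result, context.Canceled)
}

func main() {
	fmt.Println("waiter released by stop:", stopReleasesWaiter())
}
```

Without the `runCtx.Done()` case the goroutine in this sketch would block forever, which mirrors the parked Start described above.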

* [client] fix Start/Stop race by making the run loop own engine shutdown

ConnectClient.Stop stopped the engine directly while the run loop's
backoff cycle could still be starting an engine, so Engine.close raced
Engine.Start (e.g. firewall setup reading wgInterface while close nils
it). embed.Client.Start's rollback only avoided a deadlock by cancelling
before Stop; the race itself remained and was caught by -race.

Make the run loop the sole owner of engine shutdown: derive the run
context in NewConnectClient, and have Stop cancel it and wait for the
loop to exit (skipping the wait when the loop never ran) instead of
calling engine.Stop. The loop now always stops the engine on its way
out, dropping the unsynchronised wgInterface check it used to guard that
call. Self-calls from within the loop use runCancel to avoid waiting on
themselves.

embed keeps a defensive pre-Stop cancel(); the daemon's cleanupConnection
gets a TODO to adopt Stop() rather than stopping the engine in parallel.

* [client] init context state in engine tests

Engine tests built the engine context with context.WithCancel(
context.Background()), omitting CtxInitState. Now that the run context
is created in the constructor, the wgIfaceMonitor goroutine can reach
triggerClientRestart during teardown, which calls CtxGetState and
panics on the missing state. Real entry points (up, embed, service)
always call CtxInitState; only the tests skipped it.

* [client] interrupt connect backoff on context cancel

The run loop retried with a raw ExponentialBackOff, so a backoff sleep
ignored context cancellation. Now that ConnectClient.Stop waits for the
run loop to exit, a cancel landing during a sleep would block Stop for
the full interval (up to MaxInterval). Wrap the backoff with the run
context so Retry returns promptly on cancel; the retry budget itself
(MaxElapsedTime) is unchanged.
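
To illustrate why a context-aware sleep matters, here is a standard-library-only sketch of an interruptible retry loop. It is a stand-in for the idea, not the backoff library the client uses; `retry`, its parameters, and `cancelInterruptsSleep` are hypothetical names.

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// retry calls op until it succeeds, attempts run out, or ctx is cancelled.
// The wait between attempts doubles up to maxDelay and is interruptible.
func retry(ctx context.Context, attempts int, delay, maxDelay time.Duration, op func() error) error {
	var err error
	for i := 0; i < attempts; i++ {
		if err = op(); err == nil {
			return nil
		}
		if i == attempts-1 {
			break // no point sleeping after the final attempt
		}
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
		if delay *= 2; delay > maxDelay {
			delay = maxDelay
		}
	}
	return err
}

// cancelInterruptsSleep reports whether a cancel that lands during a long
// backoff sleep makes retry return quickly with the context error.
func cancelInterruptsSleep() bool {
	ctx, cancel := context.WithCancel(context.Background())
	time.AfterFunc(10*time.Millisecond, cancel)

	start := time.Now()
	err := retry(ctx, 5, time.Hour, time.Hour, func() error {
		return errors.New("management unreachable")
	})
	return errors.Is(err, context.Canceled) && time.Since(start) < 5*time.Second
}

func main() {
	fmt.Println("cancel interrupted the sleep:", cancelInterruptsSleep())
}
```

With a plain `time.Sleep(delay)` in place of the select, the same cancel would only be noticed after the full hour, which is the Stop stall the commit avoids.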

* [client] bound WaitStreamConnected in signal client tests

The tests waited on WaitStreamConnected with context.Background() and the
client's own context was also Background, so a stream that never connects
would hang until the suite timeout. Pass a 5s timeout context and assert
StreamConnected afterwards so the tests fail fast with a clear reason.

* [client] fix WaitStreamConnected stale-channel race

The StreamConnected check and the wait-channel creation took the mutex
separately, so notifyStreamConnected could set the status and close/clear
connectedCh in between: the waiter then created a fresh channel nobody
would ever close and blocked forever. Also, the status read was unlocked
while notify wrote it under the mutex (a data race). Do the check and the
channel fetch in one locked section; drop the now-unused
getStreamStatusChan helper. Pre-existing bug, not introduced by this branch.
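
The single locked section can be sketched like this. The `streamState` type and its methods are a hypothetical reduction of the status tracking, written only to show the check and the channel fetch sharing one critical section.

```go
package main

import (
	"fmt"
	"sync"
)

// streamState is a hypothetical reduction of the signal client's status tracking.
type streamState struct {
	mu          sync.Mutex
	connected   bool
	connectedCh chan struct{}
}

// connectedOrWaitChan checks the status and fetches (or creates) the wait
// channel inside one critical section, so a notify cannot slip in between.
func (s *streamState) connectedOrWaitChan() (bool, <-chan struct{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.connected {
		return true, nil
	}
	if s.connectedCh == nil {
		s.connectedCh = make(chan struct{})
	}
	return false, s.connectedCh
}

// notifyConnected sets the status and releases current waiters under the same mutex.
func (s *streamState) notifyConnected() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.connected = true
	if s.connectedCh != nil {
		close(s.connectedCh)
		s.connectedCh = nil
	}
}

func main() {
	s := &streamState{}
	_, ch := s.connectedOrWaitChan()
	s.notifyConnected()
	<-ch // returns immediately: the channel handed out earlier was closed
	connected, _ := s.connectedOrWaitChan()
	fmt.Println("connected:", connected) // connected: true
}
```

Because a waiter either sees `connected == true` or receives a channel that the next notify will close, it can no longer end up holding a fresh channel that nobody closes.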

* [client] abort Start if context cancelled while waiting for signal stream

receiveSignalEvents blocks in WaitStreamConnected until the signal stream
connects or the context is cancelled. If Stop cancelled e.ctx while Start
was parked there, Start kept going: it started the remaining subsystems on
a cancelled context and marked a shutting-down engine as started. Return
the context error from receiveSignalEvents and propagate it from Start, so
the deferred cleanup runs and the cancellation reaches the caller.

* [client] clean up all started components on Start failure

Start's failure defer only called close(), which covers the wg interface,
firewall, rosenpass and port forwarding but leaves connMgr, srWatcher,
route/DNS/flow/state managers and the monitor goroutines running. A late
failure (e.g. the context-cancelled check after the signal stream) thus
leaked them.

Extract Stop's locked teardown into stopLocked (caller holds syncMsgMux,
does not wait on shutdownWg) and call it from both Stop and Start's defer.
The defer also cancels the run context first so goroutines started before
the failure unwind. Teardown order is unchanged.
2026-06-22 13:52:57 +02:00
Zoltan Papp
f736ef9647 [client/ios] Add Auth.Stop() to cancel an in-progress interactive login (#6486)
The iOS PKCE login runs in the main-app process, decoupled from the network
extension (the extension's client context is torn down on login-required, which
would otherwise kill the WaitToken goroutine before the OAuth callback arrives).
Because it is decoupled, nothing aborted the flow when the user dismissed the
browser without logging in: WaitToken kept its loopback HTTP server bound to the
redirect port until the flow expired, so the next connect stalled trying to bind
the same port.

Make the Auth context cancellable and add Auth.Stop(), which cancels it. Cancelling
unblocks WaitToken, whose deferred server.Shutdown frees the port immediately. This
mirrors how Android's stopEngine() aborts login via the engine context.

NewAuthWithConfig now also derives a cancellable context; its only iOS caller uses
LoginSync (no interactive server), so behaviour is unchanged there.
2026-06-22 13:27:21 +02:00
Maycon Santos
cf58bf1ba9 [misc] Add TARGETPLATFORM build argument to Docker build commands (#6499) 2026-06-22 12:43:19 +02:00
Viktor Liu
522b8ed969 [client] Surface DNS forwarder upstream failures via Extended DNS Errors (#6441) 2026-06-22 12:41:33 +02:00
dependabot[bot]
c9e99659ea [misc] Bump the actions group across 1 directory with 9 updates (#6451)
Bumps the actions group with 9 updates in the / directory:

| Package | From | To |
| --- | --- | --- |
| [actions/checkout](https://github.com/actions/checkout) | `6.0.2` | `7.0.0` |
| [actions/setup-go](https://github.com/actions/setup-go) | `6.3.0` | `6.4.0` |
| [codecov/codecov-action](https://github.com/codecov/codecov-action) | `6.0.1` | `7.0.0` |
| [vmactions/freebsd-vm](https://github.com/vmactions/freebsd-vm) | `1.4.5` | `1.4.8` |
| [actions/setup-java](https://github.com/actions/setup-java) | `5.2.0` | `5.3.0` |
| [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) | `4.0.0` | `4.1.0` |
| [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) | `4.0.0` | `4.1.0` |
| [goreleaser/goreleaser-action](https://github.com/goreleaser/goreleaser-action) | `7.2.0` | `7.2.2` |
| [actions/download-artifact](https://github.com/actions/download-artifact) | `8.0.0` | `8.0.1` |



Updates `actions/checkout` from 6.0.2 to 7.0.0
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](de0fac2e45...9c091bb21b)

Updates `actions/setup-go` from 6.3.0 to 6.4.0
- [Release notes](https://github.com/actions/setup-go/releases)
- [Commits](4b73464bb3...4a3601121d)

Updates `codecov/codecov-action` from 6.0.1 to 7.0.0
- [Release notes](https://github.com/codecov/codecov-action/releases)
- [Changelog](https://github.com/codecov/codecov-action/blob/main/CHANGELOG.md)
- [Commits](e79a6962e0...fb8b3582c8)

Updates `vmactions/freebsd-vm` from 1.4.5 to 1.4.8
- [Release notes](https://github.com/vmactions/freebsd-vm/releases)
- [Commits](d1e6581156...b84ab5559b)

Updates `actions/setup-java` from 5.2.0 to 5.3.0
- [Release notes](https://github.com/actions/setup-java/releases)
- [Commits](be666c2fcd...ad2b38190b)

Updates `docker/setup-qemu-action` from 4.0.0 to 4.1.0
- [Release notes](https://github.com/docker/setup-qemu-action/releases)
- [Commits](ce360397dd...06116385d9)

Updates `docker/setup-buildx-action` from 4.0.0 to 4.1.0
- [Release notes](https://github.com/docker/setup-buildx-action/releases)
- [Commits](4d04d5d948...d7f5e7f509)

Updates `goreleaser/goreleaser-action` from 7.2.0 to 7.2.2
- [Release notes](https://github.com/goreleaser/goreleaser-action/releases)
- [Commits](4c6ab561ad...5daf1e915a)

Updates `actions/download-artifact` from 8.0.0 to 8.0.1
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](70fc10c6e5...3e5f45b2cf)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: 6.0.3
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: actions
- dependency-name: actions/download-artifact
  dependency-version: 8.0.1
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: actions
- dependency-name: actions/setup-go
  dependency-version: 6.4.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: actions/setup-java
  dependency-version: 5.3.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: codecov/codecov-action
  dependency-version: 7.0.0
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: actions
- dependency-name: docker/setup-buildx-action
  dependency-version: 4.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: docker/setup-qemu-action
  dependency-version: 4.1.0
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: actions
- dependency-name: goreleaser/goreleaser-action
  dependency-version: 7.2.2
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: actions
- dependency-name: vmactions/freebsd-vm
  dependency-version: 1.4.6
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-06-22 09:43:33 +02:00
Viktor Liu
58c79f5878 [client] Fix DNS custom zone teardown: handler leak and external CNAME resolution (#6445) 2026-06-19 17:33:09 +02:00
Viktor Liu
15a0504fb1 [client] Treat answering upstreams as reachable and widen DNS health grace window (#6453) 2026-06-19 17:32:49 +02:00
Riccardo Manfrin
883a1a8961 [client] Fix profile regressions in up --profile and status (#6479)
* Restore the behavior of creating the profile on Up if it does not exist

* Restore display of the profile name in netbird status

* [client] Reduce upFunc cognitive complexity

Extract the profile switch/auto-create logic from upFunc into a dedicated
switchOrCreateProfile helper. The inlined NotFound-retry branch pushed
upFunc over SonarCloud's cognitive complexity threshold (S3776).
No behavior change.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* [client] Make up --profile auto-create idempotent under concurrent runs

Don't fail switchOrCreateProfile on a createProfile error: a concurrent
run may create the profile between the NotFound check and our create
call. Retry the switch regardless and only surface the create error if
the switch also fails. Addresses CodeRabbit race-condition feedback.
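
A small sketch of the retry-the-switch logic, assuming an in-memory store. The `store` type, both error values, and the injected `create` function are hypothetical; only the control flow follows the description above.

```go
package main

import (
	"errors"
	"fmt"
)

var (
	errNotFound      = errors.New("profile not found")
	errAlreadyExists = errors.New("profile already exists")
)

// store is a hypothetical in-memory profile store.
type store struct {
	profiles map[string]bool
	active   string
}

func (s *store) switchTo(name string) error {
	if !s.profiles[name] {
		return errNotFound
	}
	s.active = name
	return nil
}

func (s *store) create(name string) error {
	if s.profiles[name] {
		return errAlreadyExists
	}
	s.profiles[name] = true
	return nil
}

// switchOrCreate switches to name, creating it first if needed. A create
// error is surfaced only when the retried switch fails as well.
func switchOrCreate(s *store, name string, create func(string) error) error {
	err := s.switchTo(name)
	if !errors.Is(err, errNotFound) {
		return err
	}
	createErr := create(name)
	if err := s.switchTo(name); err != nil {
		if createErr != nil {
			return fmt.Errorf("create profile %q: %w", name, createErr)
		}
		return err
	}
	return nil
}

func main() {
	s := &store{profiles: map[string]bool{}}
	// Simulate a concurrent run winning the race: the profile appears just
	// before our own create call, which then reports "already exists".
	racyCreate := func(name string) error {
		s.profiles[name] = true
		return s.create(name)
	}
	err := switchOrCreate(s, "work", racyCreate)
	fmt.Println(err, s.active) // <nil> work
}
```

The create error is kept rather than dropped, so a genuine failure (for example a store that can neither create nor switch) still reaches the caller.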

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* Share createProfile with addProfileFunc

* But allow conn reuse

* moves switchOrCreateProfile to where it's used

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-19 16:23:51 +02:00
Maycon Santos
54192a94b7 [misc] handle release candidates when fetching tags in FreeBSD port scripts (#6480)
* [misc] Exclude release candidates when fetching tags in FreeBSD port scripts
2026-06-19 14:10:43 +02:00
Pascal Fischer
8511687270 [management] log peer meta diff (#6468) 2026-06-19 13:30:52 +02:00
Pascal Fischer
35b465fa4a [management] reduce sync and login transaction (#6472) 2026-06-19 11:43:01 +02:00
Brad Ison
fb87f751a5 [management] Fetch complete user data in ValidateTunnelPeer (#6457)
* [management] Fetch complete user data in ValidateTunnelPeer

Previously the `ValidateTunnelPeer` method used by the ProxyService
would fetch user information from the database if the connected peer
was associated with a user ID, but it would not consult the IdP data
for cached info from JWT claims like email.  This caused the value of
the injected `X-Netbird-User` header to always display the peer ID and
never the user email associated with the peer as expected.

This change adds an optional IdP manager to the ProxyService and
fetches the complete user data from it if present.

* [management] Refactor ValidateTunnelPeer principal info gathering

This refactors the gathering of info on proxy tunnel peer principals
into its own method to keep the complexity down and make Sonar happy.
2026-06-19 11:39:21 +02:00
Maycon Santos
679c7182a4 [misc] Remove version prefix v from docker tags (#6471) 2026-06-18 22:34:24 +02:00
Pascal Fischer
8c031ea6f0 [management] remove db calls in nested loops (#6470) 2026-06-18 22:12:59 +02:00
Pascal Fischer
60a9544656 [management] pass meta update for browser clients (#6465) 2026-06-18 17:22:42 +02:00
Viktor Liu
d3710d4bb2 [signal] Serialize concurrent sends to a peer signal stream (#6463) 2026-06-18 15:00:19 +02:00
jnfrati
b5a16a1898 chore: move token commands under admin CLI 2026-06-04 12:49:48 +02:00
jnfrati
449b5cbb80 feat: add self-hosted admin CLI 2026-06-04 11:41:57 +02:00
66 changed files with 3542 additions and 593 deletions

View File

@@ -20,7 +20,7 @@ jobs:
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
@@ -59,12 +59,12 @@ jobs:
  runs-on: ubuntu-latest
  steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ - uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Set up Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: true

View File

@@ -15,7 +15,7 @@ jobs:
  pull-requests: write
  steps:
- - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ - uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - uses: git-town/action@3d8b878379abb1ee393fb49865a28b4a6c2cd3b0 # v1.2.1

View File

@@ -16,12 +16,12 @@ jobs:
  runs-on: macos-latest
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -48,7 +48,7 @@ jobs:
  run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -coverprofile=coverage.txt -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
  - name: Upload coverage reports to Codecov
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird

View File

@@ -16,7 +16,7 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
@@ -28,7 +28,7 @@ jobs:
  id: test
  env:
  GO_VERSION: ${{ steps.goversion.outputs.version }}
- uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5
+ uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
  with:
  usesh: true
  copyback: false

View File

@@ -18,7 +18,7 @@ jobs:
  management: ${{ steps.filter.outputs.management }}
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
@@ -30,7 +30,7 @@ jobs:
  - 'management/**'
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -119,12 +119,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -162,7 +162,7 @@ jobs:
  - name: Upload coverage reports to Codecov
  if: matrix.arch == 'amd64'
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird
@@ -175,12 +175,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -246,12 +246,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -290,7 +290,7 @@ jobs:
  - name: Upload coverage reports to Codecov
  if: matrix.arch == 'amd64'
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird
@@ -306,12 +306,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -347,7 +347,7 @@ jobs:
  - name: Upload coverage reports to Codecov
  if: matrix.arch == 'amd64'
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird
@@ -363,12 +363,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -407,7 +407,7 @@ jobs:
  - name: Upload coverage reports to Codecov
  if: matrix.arch == 'amd64'
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird
@@ -424,12 +424,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -484,7 +484,7 @@ jobs:
  - name: Upload coverage reports to Codecov
  if: matrix.arch == 'amd64'
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird
@@ -529,12 +529,12 @@ jobs:
  prom/prometheus
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -623,12 +623,12 @@ jobs:
  prom/prometheus
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -692,12 +692,12 @@ jobs:
  runs-on: ubuntu-22.04
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false
@@ -734,7 +734,7 @@ jobs:
  - name: Upload coverage reports to Codecov
  if: matrix.arch == 'amd64'
- uses: codecov/codecov-action@e79a6962e0d4c0c17b229090214935d2e33f8354 #v6.0.1
+ uses: codecov/codecov-action@fb8b3582c8e4def4969c97caa2f19720cb33a72f #v7.0.0
  with:
  token: ${{ secrets.CODECOV_TOKEN }}
  slug: netbirdio/netbird

View File

@@ -18,12 +18,12 @@ jobs:
  runs-on: windows-latest
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  id: go
  with:
  go-version-file: "go.mod"

View File

@@ -15,7 +15,7 @@ jobs:
  runs-on: ubuntu-latest
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: codespell
@@ -40,7 +40,7 @@ jobs:
  timeout-minutes: 15
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Check for duplicate constants
@@ -48,7 +48,7 @@ jobs:
  run: |
  ! awk '/const \(/,/)/{print $0}' management/server/activity/codes.go | grep -o '= [0-9]*' | sort | uniq -d | grep .
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  cache: false

View File

@@ -22,7 +22,7 @@ jobs:
  runs-on: ${{ matrix.os }}
  steps:
  - name: Checkout code
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false

View File

@@ -16,11 +16,11 @@ jobs:
  runs-on: ubuntu-latest
  steps:
  - name: Checkout repository
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  - name: Setup Android SDK
@@ -28,7 +28,7 @@ jobs:
  with:
  cmdline-tools-version: 8512546
  - name: Setup Java
- uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654
+ uses: actions/setup-java@ad2b38190b15e4d6bdf0c97fb4fca8412226d287
  with:
  java-version: "11"
  distribution: "adopt"
@@ -54,11 +54,11 @@ jobs:
  runs-on: macos-latest
  steps:
  - name: Checkout repository
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+ uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
  with:
  persist-credentials: false
  - name: Install Go
- uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0
+ uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
  with:
  go-version-file: "go.mod"
  - name: install gomobile


@@ -27,7 +27,7 @@ jobs:
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
persist-credentials: false persist-credentials: false
@@ -64,7 +64,7 @@ jobs:
if: steps.check_diff.outputs.diff_exists == 'true' if: steps.check_diff.outputs.diff_exists == 'true'
env: env:
GO_VERSION: ${{ steps.goversion.outputs.version }} GO_VERSION: ${{ steps.goversion.outputs.version }}
uses: vmactions/freebsd-vm@d1e65811565151536c0c894fff74f06351ed26e6 # v1.4.5 uses: vmactions/freebsd-vm@b84ab5559b5a1bb4b8ee2737d2506a16e1737636 # v1.4.8
with: with:
usesh: true usesh: true
copyback: false copyback: false
@@ -135,7 +135,7 @@ jobs:
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }} ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false persist-credentials: false
@@ -166,7 +166,7 @@ jobs:
fi fi
- name: Set up Go - name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -186,9 +186,9 @@ jobs:
- name: check git status - name: check git status
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Set up QEMU - name: Set up QEMU
uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a #v4.0.0 uses: docker/setup-qemu-action@06116385d9baf250c9f4dcb4858b16962ea869c3 #v4.1.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd #v4.0.0 uses: docker/setup-buildx-action@d7f5e7f509e45cec5c76c4d5afdd7de93d0b3df5 #v4.1.0
- name: Login to Docker hub - name: Login to Docker hub
if: github.event_name != 'pull_request' if: github.event_name != 'pull_request'
uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0 uses: docker/login-action@650006c6eb7dba73a995cc03b0b2d7f5ca915bee # v4.2.0
@@ -221,7 +221,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0 uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --clean ${{ env.flags }} args: release --clean ${{ env.flags }}
@@ -347,7 +347,7 @@ jobs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }} release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false persist-credentials: false
@@ -374,7 +374,7 @@ jobs:
fi fi
- name: Set up Go - name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -420,7 +420,7 @@ jobs:
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_arm64.syso
- name: Run GoReleaser - name: Run GoReleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0 uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui.yaml --clean ${{ env.flags }}
@@ -464,12 +464,12 @@ jobs:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }} - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV run: echo "flags=--snapshot" >> $GITHUB_ENV
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
fetch-depth: 0 # It is required for GoReleaser to work properly fetch-depth: 0 # It is required for GoReleaser to work properly
persist-credentials: false persist-credentials: false
- name: Set up Go - name: Set up Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
cache: false cache: false
@@ -488,7 +488,7 @@ jobs:
run: git --no-pager diff --exit-code run: git --no-pager diff --exit-code
- name: Run GoReleaser - name: Run GoReleaser
id: goreleaser id: goreleaser
uses: goreleaser/goreleaser-action@4c6ab561adb47e50c45ef534e2155934e91c40c1 # v7.2.0 uses: goreleaser/goreleaser-action@5daf1e915a5f0af01ddbcd89a43b8061ff4f1a89 # v7.2.2
with: with:
version: ${{ env.GORELEASER_VER }} version: ${{ env.GORELEASER_VER }}
args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }} args: release --config .goreleaser_ui_darwin.yaml --clean ${{ env.flags }}
@@ -522,7 +522,7 @@ jobs:
downloadPath: '${{ github.workspace }}\temp' downloadPath: '${{ github.workspace }}\temp'
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
persist-credentials: false persist-credentials: false
@@ -534,13 +534,13 @@ jobs:
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts - name: Download release artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with: with:
name: release name: release
path: release path: release
- name: Download UI release artifacts - name: Download UI release artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.1 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with: with:
name: release-ui name: release-ui
path: release-ui path: release-ui


@@ -68,12 +68,12 @@ jobs:
run: sudo apt-get install -y curl run: sudo apt-get install -y curl
- name: Checkout code - name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
@@ -207,7 +207,7 @@ jobs:
- name: Build management docker image - name: Build management docker image
working-directory: management working-directory: management
run: | run: |
docker build -t netbirdio/management:latest . docker build -t netbirdio/management:latest --build-arg TARGETPLATFORM=. .
- name: Build signal binary - name: Build signal binary
working-directory: signal working-directory: signal
@@ -216,7 +216,7 @@ jobs:
- name: Build signal docker image - name: Build signal docker image
working-directory: signal working-directory: signal
run: | run: |
docker build -t netbirdio/signal:latest . docker build -t netbirdio/signal:latest --build-arg TARGETPLATFORM=. .
- name: Build relay binary - name: Build relay binary
working-directory: relay working-directory: relay
@@ -225,7 +225,7 @@ jobs:
- name: Build relay docker image - name: Build relay docker image
working-directory: relay working-directory: relay
run: | run: |
docker build -t netbirdio/relay:latest . docker build -t netbirdio/relay:latest --build-arg TARGETPLATFORM=. .
- name: run docker compose up - name: run docker compose up
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
@@ -256,7 +256,7 @@ jobs:
run: sudo apt-get install -y jq run: sudo apt-get install -y jq
- name: Checkout code - name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
persist-credentials: false persist-credentials: false


@@ -19,11 +19,11 @@ jobs:
GOARCH: wasm GOARCH: wasm
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Install dependencies - name: Install dependencies
@@ -44,11 +44,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@9c091bb21b7c1c1d1991bb908d89e4e9dddfe3e0 # v7.0.0
with: with:
persist-credentials: false persist-credentials: false
- name: Install Go - name: Install Go
uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6.4.0
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- name: Build Wasm client - name: Build Wasm client


@@ -247,7 +247,7 @@ dockers_v2:
- netbirdio/netbird - netbirdio/netbird
- ghcr.io/netbirdio/netbird - ghcr.io/netbirdio/netbird
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile dockerfile: client/Dockerfile
extra_files: extra_files:
@@ -295,7 +295,7 @@ dockers_v2:
- netbirdio/relay - netbirdio/relay
- ghcr.io/netbirdio/relay - ghcr.io/netbirdio/relay
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile dockerfile: relay/Dockerfile
platforms: platforms:
@@ -317,7 +317,7 @@ dockers_v2:
- netbirdio/signal - netbirdio/signal
- ghcr.io/netbirdio/signal - ghcr.io/netbirdio/signal
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile dockerfile: signal/Dockerfile
platforms: platforms:
@@ -339,7 +339,7 @@ dockers_v2:
- netbirdio/management - netbirdio/management
- ghcr.io/netbirdio/management - ghcr.io/netbirdio/management
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile dockerfile: management/Dockerfile
platforms: platforms:
@@ -361,7 +361,7 @@ dockers_v2:
- netbirdio/upload - netbirdio/upload
- ghcr.io/netbirdio/upload - ghcr.io/netbirdio/upload
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile dockerfile: upload-server/Dockerfile
platforms: platforms:
@@ -383,7 +383,7 @@ dockers_v2:
- netbirdio/netbird-server - netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server - ghcr.io/netbirdio/netbird-server
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile dockerfile: combined/Dockerfile
platforms: platforms:
@@ -405,7 +405,7 @@ dockers_v2:
- netbirdio/reverse-proxy - netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy - ghcr.io/netbirdio/reverse-proxy
tags: tags:
- "v{{ .Version }}" - "{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}" - "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile dockerfile: proxy/Dockerfile
platforms: platforms:
@@ -462,9 +462,13 @@ checksum:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh - glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh - glob: ./infrastructure_files/getting-started.sh
- glob: ./infrastructure_files/getting-started-enterprise.sh
- glob: ./infrastructure_files/migrate-to-enterprise.sh
release: release:
extra_files: extra_files:
- glob: ./infrastructure_files/getting-started-with-zitadel.sh - glob: ./infrastructure_files/getting-started-with-zitadel.sh
- glob: ./release_files/install.sh - glob: ./release_files/install.sh
- glob: ./infrastructure_files/getting-started.sh - glob: ./infrastructure_files/getting-started.sh
- glob: ./infrastructure_files/getting-started-enterprise.sh
- glob: ./infrastructure_files/migrate-to-enterprise.sh


@@ -227,7 +227,7 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
Username: &username, Username: &username,
}) })
if err != nil { if err != nil {
return "", fmt.Errorf("switch profile failed: %v", err) return "", fmt.Errorf("switch profile failed: %w", err)
} }
return profilemanager.ID(resp.Id), nil return profilemanager.ID(resp.Id), nil


@@ -138,26 +138,23 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
return err return err
} }
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr) conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil { if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err) return fmt.Errorf("connect to service CLI interface: %w", err)
} }
defer conn.Close() defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn) daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0] profileName := args[0]
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{ id, err := addProfileOnDaemon(cmd.Context(), daemonClient, profileName, currUser.Username)
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil { if err != nil {
return fmt.Errorf("add profile request: %w", err) return err
} }
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName) dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
@@ -166,7 +163,6 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.") cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
} }
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName)) cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil return nil
@@ -330,3 +326,19 @@ func wrapAmbiguityError(err error, handle string) error {
} }
return err return err
} }
// addProfileOnDaemon issues the AddProfile RPC on an existing daemon client
// and returns the new profile's ID. It is the single entry point for profile
// creation, shared by `netbird profile add` and the `netbird up --profile
// <name>` auto-create path.
func addProfileOnDaemon(ctx context.Context, client proto.DaemonServiceClient, profileName, username string) (profilemanager.ID, error) {
resp, err := client.AddProfile(ctx, &proto.AddProfileRequest{
ProfileName: profileName,
Username: username,
})
if err != nil {
return "", fmt.Errorf("add profile failed: %w", err)
}
return profilemanager.ID(resp.Id), nil
}


@@ -11,7 +11,6 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
@@ -111,11 +110,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
return nil return nil
} }
pm := profilemanager.NewProfileManager() // Resolve the active profile's display name via the daemon, which runs
var profName string // as root and can read the per-user profile files. The local profile
if activeProf, err := pm.GetActiveProfile(); err == nil { // manager only knows the active profile ID, not its display name.
profName = activeProf.Name profName := getActiveProfileName(ctx)
}
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{ var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
Anonymize: anonymizeFlag, Anonymize: anonymizeFlag,
@@ -167,6 +165,25 @@ func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (
return resp, nil return resp, nil
} }
// getActiveProfileName asks the daemon for the active profile's display
// name. The daemon runs as root and can read the per-user profile files to
// resolve the ID to its human-readable name. Returns an empty string on any
// error so status output degrades gracefully.
func getActiveProfileName(ctx context.Context) string {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return ""
}
defer conn.Close()
resp, err := proto.NewDaemonServiceClient(conn).GetActiveProfile(ctx, &proto.GetActiveProfileRequest{})
if err != nil {
return ""
}
return resp.GetProfileName()
}
func parseFilters() error { func parseFilters() error {
switch strings.ToLower(statusFilter) { switch strings.ToLower(statusFilter) {
case "", "idle", "connecting", "connected": case "", "idle", "connecting", "connected":


@@ -128,15 +128,9 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool var profileSwitched bool
// switch profile if provided // switch profile if provided
if profileName != "" { if profileName != "" {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username) if err := switchOrCreateProfile(cmd.Context(), pm, profileName, username.Username); err != nil {
if err != nil {
return fmt.Errorf("switch profile: %v", err) return fmt.Errorf("switch profile: %v", err)
} }
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
profileSwitched = true profileSwitched = true
} }
@@ -151,6 +145,52 @@ func upFunc(cmd *cobra.Command, args []string) error {
return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched) return runInDaemonMode(ctx, cmd, pm, activeProf, profileSwitched)
} }
// switchOrCreateProfile switches the active profile to the one identified by
// handle, creating it first when it does not exist yet. This restores the
// pre-0.73 behaviour where `netbird up --profile <name>` auto-creates a
// missing profile instead of failing.
func switchOrCreateProfile(ctx context.Context, pm *profilemanager.ProfileManager, handle, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
st, ok := gstatus.FromError(err)
if !ok || st.Code() != codes.NotFound {
return err
}
// Don't fail immediately on a create error: a concurrent run may
// have created the profile between the NotFound above and this
// call, in which case the retried switch still succeeds. Only
// surface the create error if the switch also fails.
_, createErr := createProfile(ctx, handle, username)
if resolvedID, err = switchProfile(ctx, handle, username); err != nil {
if createErr != nil {
return fmt.Errorf("create profile: %w", createErr)
}
return err
}
}
if err := pm.SwitchProfile(resolvedID); err != nil {
return err
}
return nil
}
// createProfile dials the daemon and creates a new profile with the given
// display name, returning its generated ID. Use addProfileOnDaemon directly
// when a daemon client is already available to reuse the connection.
func createProfile(ctx context.Context, profileName, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer conn.Close()
return addProfileOnDaemon(ctx, proto.NewDaemonServiceClient(conn), profileName, username)
}
func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error { func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *profilemanager.Profile) error {
// override the default profile filepath if provided // override the default profile filepath if provided
if configPath != "" { if configPath != "" {
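
The create-then-retry flow in switchOrCreateProfile can be sketched in isolation as below. This is a minimal illustration, not the project's code: the `store` type, its `switchTo` and `create` methods, and the `errNotFound` sentinel are hypothetical stand-ins for the daemon RPCs and the gRPC `codes.NotFound` status. It also suggests why switchProfile now wraps with `%w`: the caller has to be able to recover the underlying cause from the wrapped error.

```go
package main

import (
	"errors"
	"fmt"
)

// errNotFound stands in for a gRPC codes.NotFound status (assumption for this sketch).
var errNotFound = errors.New("profile not found")

// store is a hypothetical in-memory replacement for the daemon.
type store struct {
	profiles map[string]string // display name -> ID
	nextID   int
}

// switchTo resolves a profile name to its ID, wrapping errNotFound with %w
// so callers can still detect it with errors.Is.
func (s *store) switchTo(name string) (string, error) {
	id, ok := s.profiles[name]
	if !ok {
		return "", fmt.Errorf("switch profile failed: %w", errNotFound)
	}
	return id, nil
}

// create adds a profile and fails if the name is already taken.
func (s *store) create(name string) (string, error) {
	if _, ok := s.profiles[name]; ok {
		return "", errors.New("profile already exists")
	}
	s.nextID++
	id := fmt.Sprintf("id-%d", s.nextID)
	s.profiles[name] = id
	return id, nil
}

// switchOrCreate switches to the named profile, creating it when missing.
// A create error is reported only if the retried switch also fails, so a
// concurrent creator winning the race does not turn into a failure here.
func switchOrCreate(s *store, name string) (string, error) {
	id, err := s.switchTo(name)
	if err == nil {
		return id, nil
	}
	if !errors.Is(err, errNotFound) {
		return "", err
	}
	_, createErr := s.create(name)
	id, err = s.switchTo(name)
	if err != nil {
		if createErr != nil {
			return "", fmt.Errorf("create profile: %w", createErr)
		}
		return "", err
	}
	return id, nil
}

func main() {
	s := &store{profiles: map[string]string{}}
	id, err := switchOrCreate(s, "work")
	fmt.Println(id, err) // id-1 <nil>
}
```

Calling switchOrCreate a second time with the same name takes the fast path and returns the existing ID without creating anything.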


@@ -279,9 +279,11 @@ func (c *Client) Start(startCtx context.Context) error {
select { select {
case <-startCtx.Done(): case <-startCtx.Done():
// Cancel the client context before stopping: Engine.Start blocks on the // ConnectClient.Stop now cancels its own run context and waits for the
// signal stream while holding the engine mutex and only unblocks on // run loop to tear the engine down, so this cancel() is no longer
// cancellation. Stopping first would deadlock on that mutex. // required to break the deadlock and could be removed. It is kept as a
// defensive belt-and-suspenders: cancelling the parent context first
// guarantees the run loop is unblocked even if Stop's contract regresses.
cancel() cancel()
if stopErr := client.Stop(); stopErr != nil { if stopErr := client.Stop(); stopErr != nil {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())


@@ -11,6 +11,7 @@ import (
"runtime/debug" "runtime/debug"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
@@ -54,6 +55,10 @@ var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath
type ConnectClient struct { type ConnectClient struct {
ctx context.Context ctx context.Context
runCancel context.CancelFunc
runExited chan struct{}
runOnce sync.Once
runStarted atomic.Bool
config *profilemanager.Config config *profilemanager.Config
statusRecorder *peer.Status statusRecorder *peer.Status
@@ -70,8 +75,14 @@ func NewConnectClient(
config *profilemanager.Config, config *profilemanager.Config,
statusRecorder *peer.Status, statusRecorder *peer.Status,
) *ConnectClient { ) *ConnectClient {
// Derive the run context here so Stop owns the cancel that unblocks the run
// loop. runCancel is set once at construction, so Stop can call it without
// racing the run loop's startup. Callers therefore need not cancel before Stop.
runCtx, runCancel := context.WithCancel(ctx)
return &ConnectClient{ return &ConnectClient{
ctx: ctx, ctx: runCtx,
runCancel: runCancel,
runExited: make(chan struct{}),
config: config, config: config,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
engineMutex: sync.Mutex{}, engineMutex: sync.Mutex{},
@@ -135,6 +146,11 @@ func (c *ConnectClient) RunOniOS(
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}, logPath string) error {
// Mark the loop as started and signal exit on return so Stop can wait for
// the loop to finish (and skip the wait if the loop never ran).
c.runStarted.Store(true)
defer c.runOnce.Do(func() { close(c.runExited) })
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -290,7 +306,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Debug(err) log.Debug(err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
_ = c.Stop() c.runCancel()
return backoff.Permanent(wrapErr(err)) // unrecoverable error return backoff.Permanent(wrapErr(err)) // unrecoverable error
} }
return wrapErr(err) return wrapErr(err)
@@ -410,14 +426,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.engine = nil c.engine = nil
c.engineMutex.Unlock() c.engineMutex.Unlock()
// todo: consider to remove this condition. Is not thread safe. log.Infof("ensuring wg interface is removed, Netbird engine context cancelled")
// We should always call Stop(), but we need to verify that it is idempotent
if engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil { if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err) log.Errorf("Failed to stop engine: %v", err)
}
} }
c.statusRecorder.ClientTeardown() c.statusRecorder.ClientTeardown()
@@ -433,12 +445,12 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
c.statusRecorder.ClientStart() c.statusRecorder.ClientStart()
err = backoff.Retry(operation, backOff) err = backoff.Retry(operation, backoff.WithContext(backOff, c.ctx))
if err != nil { if err != nil {
log.Debugf("exiting client retry loop due to unrecoverable error: %s", err) log.Debugf("exiting client retry loop due to unrecoverable error: %s", err)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) { if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
state.Set(StatusNeedsLogin) state.Set(StatusNeedsLogin)
_ = c.Stop() c.runCancel()
} }
return err return err
} }
@@ -516,11 +528,9 @@ func (c *ConnectClient) Status() StatusType {
} }
func (c *ConnectClient) Stop() error { func (c *ConnectClient) Stop() error {
engine := c.Engine() c.runCancel()
if engine != nil { if c.runStarted.Load() {
if err := engine.Stop(); err != nil { <-c.runExited
return fmt.Errorf("stop engine: %w", err)
}
} }
return nil return nil
} }


@@ -207,3 +207,35 @@ func FormatAnswers(answers []dns.RR) string {
} }
return "[" + strings.Join(parts, ", ") + "]" return "[" + strings.Join(parts, ", ") + "]"
} }
// StripOPT removes any OPT pseudo-RRs from the message's Extra section. Per
// RFC 6891 a responder must not include an OPT RR toward a client that did not
// advertise EDNS0.
func StripOPT(msg *dns.Msg) {
if len(msg.Extra) == 0 {
return
}
out := msg.Extra[:0]
for _, rr := range msg.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
msg.Extra = out
}
// ExtractEDE returns the first Extended DNS Error (RFC 8914) option carried in
// the message, if present.
func ExtractEDE(msg *dns.Msg) (*dns.EDNS0_EDE, bool) {
opt := msg.IsEdns0()
if opt == nil {
return nil, false
}
for _, o := range opt.Option {
if ede, ok := o.(*dns.EDNS0_EDE); ok {
return ede, true
}
}
return nil, false
}
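
StripOPT uses the `s[:0]` in-place filter idiom. A generic sketch of the same technique, using plain strings instead of `dns.RR` values (the `filterInPlace` helper is illustrative and not part of the change):

```go
package main

import "fmt"

// filterInPlace keeps the elements for which keep returns true. It reuses
// the backing array of s, so no new slice is allocated.
func filterInPlace[T any](s []T, keep func(T) bool) []T {
	out := s[:0] // zero length, same backing array as s
	for _, v := range s {
		if keep(v) {
			out = append(out, v)
		}
	}
	return out
}

func main() {
	records := []string{"OPT", "A", "OPT", "AAAA"}
	records = filterInPlace(records, func(r string) bool { return r != "OPT" })
	fmt.Println(records) // [A AAAA]
}
```

Because the backing array is shared, any other slice that aliases the original will see the overwritten elements. That is presumably acceptable in StripOPT, which assigns the result straight back to `msg.Extra`.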


@@ -120,3 +120,42 @@ func TestLookupIP_DNSErrorNotIsNotFound(t *testing.T) {
assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL") assert.Equal(t, dns.RcodeServerFailure, result.Rcode, "upstream failure should map to SERVFAIL")
} }
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
StripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestExtractEDE(t *testing.T) {
t.Run("no edns", func(t *testing.T) {
_, ok := ExtractEDE(&dns.Msg{})
assert.False(t, ok, "message without OPT has no EDE")
})
t.Run("edns without ede", func(t *testing.T) {
rm := &dns.Msg{}
rm.SetEdns0(4096, false)
_, ok := ExtractEDE(rm)
assert.False(t, ok, "OPT without EDE option returns false")
})
t.Run("with ede", func(t *testing.T) {
rm := &dns.Msg{}
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: 49152, ExtraText: "upstream timeout"})
rm.Extra = append(rm.Extra, opt)
ede, ok := ExtractEDE(rm)
assert.True(t, ok, "EDE option should be found")
assert.Equal(t, uint16(49152), ede.InfoCode)
assert.Equal(t, "upstream timeout", ede.ExtraText)
})
}


@@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"net/url"
+ "os"
"slices"
"strings"
"sync"
@@ -38,11 +39,15 @@ const (
// defaultWarningDelayBase is the starting grace window before a
// "Nameserver group unreachable" event fires for a group that's
// never been healthy and only has overlay upstreams with no
- // Connected peer. Per-server and overridable; see warningDelayFor.
- defaultWarningDelayBase = 30 * time.Second
+ // Connected peer. Per-server and overridable via envWarningDelay;
+ // see warningDelay.
+ defaultWarningDelayBase = 60 * time.Second
// warningDelayBonusCap caps the route-count bonus added to the
- // base grace window. See warningDelayFor.
+ // base grace window. See warningDelay.
warningDelayBonusCap = 30 * time.Second
+ // envWarningDelay overrides defaultWarningDelayBase with a Go duration
+ // string (e.g. "90s", "2m"). Invalid or non-positive values are ignored.
+ envWarningDelay = "NB_DNS_HEALTH_WARNING_DELAY"
)
// errNoUsableNameservers signals that a merged-domain group has no usable
@@ -135,7 +140,7 @@ type DefaultServer struct {
disableSys bool
mux sync.Mutex
service service
- dnsMuxMap registeredHandlerMap
+ dnsMuxHandlers []handlerWrapper
localResolver *local.Resolver
wgInterface WGIface
hostManager hostManager
@@ -199,8 +204,6 @@ type handlerWrapper struct {
priority int
}
- type registeredHandlerMap map[types.HandlerID]handlerWrapper
// DefaultServerConfig holds configuration parameters for NewDefaultServer
type DefaultServerConfig struct {
WgInterface WGIface
@@ -289,7 +292,6 @@ func newDefaultServer(
service: dnsService,
handlerChain: handlerChain,
extraDomains: make(map[domain.Domain]int),
- dnsMuxMap: make(registeredHandlerMap),
localResolver: local.NewResolver(),
wgInterface: wgInterface,
statusRecorder: statusRecorder,
@@ -298,7 +300,7 @@ func newDefaultServer(
hostManager: &noopHostConfigurator{},
mgmtCacheResolver: mgmtCacheResolver,
currentConfigHash: ^uint64(0), // Initialize to max uint64 to ensure first config is always applied
- warningDelayBase: defaultWarningDelayBase,
+ warningDelayBase: warningDelayBaseFromEnv(),
healthRefresh: make(chan struct{}, 1),
}
// Wire the local resolver against the peer status recorder so it can
@@ -328,7 +330,7 @@ func (s *DefaultServer) SetRouteSources(selected, active func() route.HAMap) {
type routeSettable interface {
setSelectedRoutes(func() route.HAMap)
}
- for _, entry := range s.dnsMuxMap {
+ for _, entry := range s.dnsMuxHandlers {
if h, ok := entry.handler.(routeSettable); ok {
h.setSelectedRoutes(selected)
}
@@ -978,19 +980,23 @@ func (s *DefaultServer) usableNameServers(nameServers []nbdns.NameServer) []neti
func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
// this will introduce a short period of time when the server is not able to handle DNS requests
- for _, existing := range s.dnsMuxMap {
+ for _, existing := range s.dnsMuxHandlers {
s.deregisterHandler([]string{existing.domain}, existing.priority)
- existing.handler.Stop()
+ // The local resolver is a persistent singleton shared by every custom
+ // zone and reused across config updates. Its chain registrations are
+ // per-config and must be deregistered, but Stop() cancels its lookup
+ // context (breaking external CNAME-target resolution) and clears its
+ // records, so it must not be torn down here.
+ if existing.handler != s.localResolver {
+ existing.handler.Stop()
+ }
}
- muxUpdateMap := make(registeredHandlerMap)
for _, update := range muxUpdates {
s.registerHandler([]string{update.domain}, update.handler, update.priority)
- muxUpdateMap[update.handler.ID()] = update
}
- s.dnsMuxMap = muxUpdateMap
+ s.dnsMuxHandlers = muxUpdates
}
// updateNSGroupStates records the new group set and pokes the refresher.
@@ -1154,6 +1160,26 @@ func (s *DefaultServer) projectUnhealthy(p *nsGroupProj, servers []netip.AddrPor
return false
}
+ // warningDelayBaseFromEnv returns the base grace window, honoring
+ // envWarningDelay when it holds a valid positive Go duration. Invalid or
+ // non-positive values fall back to defaultWarningDelayBase.
+ func warningDelayBaseFromEnv() time.Duration {
+ val := os.Getenv(envWarningDelay)
+ if val == "" {
+ return defaultWarningDelayBase
+ }
+ d, err := time.ParseDuration(val)
+ if err != nil {
+ log.Warnf("invalid %s value %q, using default %v: %v", envWarningDelay, val, defaultWarningDelayBase, err)
+ return defaultWarningDelayBase
+ }
+ if d <= 0 {
+ log.Warnf("%s must be positive, got %v, using default %v", envWarningDelay, d, defaultWarningDelayBase)
+ return defaultWarningDelayBase
+ }
+ return d
+ }
// warningDelay returns the grace window for the given selected-route
// count. Scales gently: +1s per 100 routes, capped by
// warningDelayBonusCap. Parallel handshakes mean handshake time grows
@@ -1204,7 +1230,7 @@ func (s *DefaultServer) groupHasImmediateUpstream(servers []netip.AddrPort, snap
// in more than one handler.
func (s *DefaultServer) collectUpstreamHealth() map[netip.AddrPort]UpstreamHealth {
merged := make(map[netip.AddrPort]UpstreamHealth)
- for _, entry := range s.dnsMuxMap {
+ for _, entry := range s.dnsMuxHandlers {
reporter, ok := entry.handler.(upstreamHealthReporter)
if !ok {
continue


@@ -104,19 +104,6 @@ func init() {
formatter.SetTextFormatter(log.StandardLogger()) formatter.SetTextFormatter(log.StandardLogger())
} }
func generateDummyHandler(d string, servers []nbdns.NameServer) *upstreamResolverBase {
var srvs []netip.AddrPort
for _, srv := range servers {
srvs = append(srvs, srv.AddrPort())
}
u := &upstreamResolverBase{
domain: domain.Domain(d),
cancel: func() {},
}
u.addRace(srvs)
return u
}
func TestUpdateDNSServer(t *testing.T) { func TestUpdateDNSServer(t *testing.T) {
nameServers := []nbdns.NameServer{ nameServers := []nbdns.NameServer{
@@ -132,22 +119,20 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
} }
dummyHandler := local.NewResolver()
testCases := []struct { testCases := []struct {
name string name string
initUpstreamMap registeredHandlerMap initUpstreamMap []handlerWrapper
initLocalZones []nbdns.CustomZone initLocalZones []nbdns.CustomZone
initSerial uint64 initSerial uint64
inputSerial uint64 inputSerial uint64
inputUpdate nbdns.Config inputUpdate nbdns.Config
shouldFail bool shouldFail bool
expectedUpstreamMap registeredHandlerMap expectedUpstreamMap []handlerWrapper
expectedLocalQs []dns.Question expectedLocalQs []dns.Question
}{ }{
{ {
name: "Initial Config Should Succeed", name: "Initial Config Should Succeed",
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: nil,
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@@ -169,20 +154,17 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{ expectedUpstreamMap: []handlerWrapper{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ {
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
dummyHandler.ID(): handlerWrapper{ {
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal, priority: PriorityLocal,
}, },
generateDummyHandler(".", nameServers).ID(): handlerWrapper{ {
domain: nbdns.RootZone, domain: nbdns.RootZone,
handler: dummyHandler,
priority: PriorityDefault, priority: PriorityDefault,
}, },
}, },
@@ -191,10 +173,10 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "New Config Should Succeed", name: "New Config Should Succeed",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: 1, Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: []handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ {
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: &mockHandler{},
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
}, },
@@ -215,15 +197,13 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{ expectedUpstreamMap: []handlerWrapper{
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ {
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler,
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
"local-resolver": handlerWrapper{ {
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler,
priority: PriorityLocal, priority: PriorityLocal,
}, },
}, },
@@ -232,7 +212,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Smaller Config Serial Should Be Skipped", name: "Smaller Config Serial Should Be Skipped",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: nil,
initSerial: 2, initSerial: 2,
inputSerial: 1, inputSerial: 1,
shouldFail: true, shouldFail: true,
@@ -240,7 +220,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Empty NS Group Domain Or Not Primary Element Should Fail", name: "Empty NS Group Domain Or Not Primary Element Should Fail",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: nil,
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@@ -262,7 +242,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Invalid NS Group Nameservers list Should Fail", name: "Invalid NS Group Nameservers list Should Fail",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: nil,
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@@ -284,7 +264,7 @@ func TestUpdateDNSServer(t *testing.T) {
{ {
name: "Invalid Custom Zone Records list Should Skip", name: "Invalid Custom Zone Records list Should Skip",
initLocalZones: []nbdns.CustomZone{}, initLocalZones: []nbdns.CustomZone{},
initUpstreamMap: make(registeredHandlerMap), initUpstreamMap: nil,
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ inputUpdate: nbdns.Config{
@@ -301,42 +281,41 @@ func TestUpdateDNSServer(t *testing.T) {
}, },
}, },
}, },
expectedUpstreamMap: registeredHandlerMap{generateDummyHandler(".", nameServers).ID(): handlerWrapper{ expectedUpstreamMap: []handlerWrapper{{
domain: ".", domain: ".",
handler: dummyHandler,
priority: PriorityDefault, priority: PriorityDefault,
}}, }},
}, },
{ {
name: "Empty Config Should Succeed and Clean Maps", name: "Empty Config Should Succeed and Clean Maps",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: []handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ {
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: &mockHandler{},
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: true}, inputUpdate: nbdns.Config{ServiceEnable: true},
expectedUpstreamMap: make(registeredHandlerMap), expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{}, expectedLocalQs: []dns.Question{},
}, },
{ {
name: "Disabled Service Should clean map", name: "Disabled Service Should clean map",
initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}}, initLocalZones: []nbdns.CustomZone{{Domain: "netbird.cloud", Records: []nbdns.SimpleRecord{{Name: "netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}}}},
initUpstreamMap: registeredHandlerMap{ initUpstreamMap: []handlerWrapper{
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ {
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: &mockHandler{},
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
inputSerial: 1, inputSerial: 1,
inputUpdate: nbdns.Config{ServiceEnable: false}, inputUpdate: nbdns.Config{ServiceEnable: false},
expectedUpstreamMap: make(registeredHandlerMap), expectedUpstreamMap: nil,
expectedLocalQs: []dns.Question{}, expectedLocalQs: []dns.Question{},
}, },
} }
@@ -393,7 +372,7 @@ func TestUpdateDNSServer(t *testing.T) {
} }
}() }()
dnsServer.dnsMuxMap = testCase.initUpstreamMap dnsServer.dnsMuxHandlers = testCase.initUpstreamMap
dnsServer.localResolver.Update(testCase.initLocalZones) dnsServer.localResolver.Update(testCase.initLocalZones)
dnsServer.updateSerial = testCase.initSerial dnsServer.updateSerial = testCase.initSerial
@@ -405,14 +384,20 @@ func TestUpdateDNSServer(t *testing.T) {
t.Fatalf("update dns server should not fail, got error: %v", err) t.Fatalf("update dns server should not fail, got error: %v", err)
} }
if len(dnsServer.dnsMuxMap) != len(testCase.expectedUpstreamMap) { if len(dnsServer.dnsMuxHandlers) != len(testCase.expectedUpstreamMap) {
t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxMap)) t.Fatalf("update upstream failed, map size is different than expected, want %d, got %d", len(testCase.expectedUpstreamMap), len(dnsServer.dnsMuxHandlers))
} }
for key := range testCase.expectedUpstreamMap { for _, expected := range testCase.expectedUpstreamMap {
_, found := dnsServer.dnsMuxMap[key] found := false
for _, got := range dnsServer.dnsMuxHandlers {
if got.domain == expected.domain && got.priority == expected.priority {
found = true
break
}
}
if !found { if !found {
t.Fatalf("update upstream failed, key %s was not found in the dnsMuxMap: %#v", key, dnsServer.dnsMuxMap) t.Fatalf("update upstream failed, handler for domain=%s priority=%d not found in dnsMuxHandlers: %#v", expected.domain, expected.priority, dnsServer.dnsMuxHandlers)
} }
} }
@@ -512,8 +497,8 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
}() }()
dnsServer.dnsMuxMap = registeredHandlerMap{ dnsServer.dnsMuxHandlers = []handlerWrapper{
"id1": handlerWrapper{ {
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityUpstream, priority: PriorityUpstream,
@@ -1029,15 +1014,15 @@ func (m *mockService) RegisterMux(string, dns.Handler) {}
func (m *mockService) DeregisterMux(string) {} func (m *mockService) DeregisterMux(string) {}
func TestDefaultServer_UpdateMux(t *testing.T) { func TestDefaultServer_UpdateMux(t *testing.T) {
baseMatchHandlers := registeredHandlerMap{ baseMatchHandlers := []handlerWrapper{
"upstream-group1": { {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
"upstream-group2": { {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
@@ -1046,15 +1031,15 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}, },
} }
baseRootHandlers := registeredHandlerMap{ baseRootHandlers := []handlerWrapper{
"upstream-root1": { {
domain: ".", domain: ".",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-root1", Id: "upstream-root1",
}, },
priority: PriorityDefault, priority: PriorityDefault,
}, },
"upstream-root2": { {
domain: ".", domain: ".",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-root2", Id: "upstream-root2",
@@ -1063,22 +1048,22 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
}, },
} }
baseMixedHandlers := registeredHandlerMap{ baseMixedHandlers := []handlerWrapper{
"upstream-group1": { {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityUpstream, priority: PriorityUpstream,
}, },
"upstream-group2": { {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityUpstream - 1, priority: PriorityUpstream - 1,
}, },
"upstream-other": { {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
@@ -1089,7 +1074,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
initialHandlers registeredHandlerMap initialHandlers []handlerWrapper
updates []handlerWrapper updates []handlerWrapper
expectedHandlers map[string]string // map[HandlerID]domain expectedHandlers map[string]string // map[HandlerID]domain
description string description string
@@ -1373,32 +1358,38 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := &DefaultServer{ server := &DefaultServer{
dnsMuxMap: tt.initialHandlers, dnsMuxHandlers: tt.initialHandlers,
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
service: &mockService{}, service: &mockService{},
} }
// Perform the update // Perform the update
server.updateMux(tt.updates) server.updateMux(tt.updates)
// Verify the results // Verify the results
assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxMap), assert.Equal(t, len(tt.expectedHandlers), len(server.dnsMuxHandlers),
"Number of handlers after update doesn't match expected") "Number of handlers after update doesn't match expected")
// Check each expected handler // Check each expected handler
for id, expectedDomain := range tt.expectedHandlers { for id, expectedDomain := range tt.expectedHandlers {
handler, exists := server.dnsMuxMap[types.HandlerID(id)] var found *handlerWrapper
assert.True(t, exists, "Expected handler %s not found", id) for i := range server.dnsMuxHandlers {
if exists { if server.dnsMuxHandlers[i].handler.ID() == types.HandlerID(id) {
assert.Equal(t, expectedDomain, handler.domain, found = &server.dnsMuxHandlers[i]
break
}
}
assert.NotNil(t, found, "Expected handler %s not found", id)
if found != nil {
assert.Equal(t, expectedDomain, found.domain,
"Domain mismatch for handler %s", id) "Domain mismatch for handler %s", id)
} }
} }
// Verify no unexpected handlers exist // Verify no unexpected handlers exist
for HandlerID := range server.dnsMuxMap { for _, entry := range server.dnsMuxHandlers {
_, expected := tt.expectedHandlers[string(HandlerID)] _, expected := tt.expectedHandlers[string(entry.handler.ID())]
assert.True(t, expected, "Unexpected handler found: %s", HandlerID) assert.True(t, expected, "Unexpected handler found: %s", entry.handler.ID())
} }
// Verify the handlerChain state and order // Verify the handlerChain state and order
@@ -1413,7 +1404,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
// Verify handler exists in mux // Verify handler exists in mux
foundInMux := false foundInMux := false
for _, muxEntry := range server.dnsMuxMap { for _, muxEntry := range server.dnsMuxHandlers {
if chainEntry.Handler == muxEntry.handler && if chainEntry.Handler == muxEntry.handler &&
chainEntry.Priority == muxEntry.priority && chainEntry.Priority == muxEntry.priority &&
chainEntry.Pattern == dns.Fqdn(muxEntry.domain) { chainEntry.Pattern == dns.Fqdn(muxEntry.domain) {
@@ -1422,12 +1413,108 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
} }
} }
assert.True(t, foundInMux, assert.True(t, foundInMux,
"Handler in chain not found in dnsMuxMap") "Handler in chain not found in dnsMuxHandlers")
} }
}) })
} }
} }
// chainHasPattern reports whether the handler chain holds an entry registered
// for the given fqdn pattern at the given priority.
func chainHasPattern(s *DefaultServer, pattern string, priority int) bool {
for _, h := range s.handlerChain.handlers {
if h.OrigPattern == pattern && h.Priority == priority {
return true
}
}
return false
}
// TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval verifies that updateMux
// tracks each (handler, domain) registration independently when one handler
// serves multiple zones. Every custom zone is served by the same handler
// instance (the local resolver, whose ID is the constant "local-resolver"), so
// removing one zone must deregister exactly that zone's chain entry and leave
// the others in place. Tracking registrations by handler ID alone collapses all
// zones onto one entry, leaving removed zones in the chain to answer
// authoritatively with no records.
func TestDefaultServer_UpdateMux_SharedHandlerZoneRemoval(t *testing.T) {
// One handler serves every custom zone, mirroring s.localResolver.
shared := &mockHandler{Id: "local-resolver"}
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
}
// Two custom zones under the same handler. The surviving zone is registered
// last, mirroring the management emission order.
server.updateMux([]handlerWrapper{
{domain: "userzone.test", handler: shared, priority: PriorityLocal},
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
require.True(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test should be registered after the first update")
require.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should be registered after the first update")
// Remove one zone, keep the other.
server.updateMux([]handlerWrapper{
{domain: "peerzone.test", handler: shared, priority: PriorityLocal},
})
assert.True(t, chainHasPattern(server, "peerzone.test.", PriorityLocal),
"peerzone.test should remain after removing userzone.test")
assert.False(t, chainHasPattern(server, "userzone.test.", PriorityLocal),
"userzone.test handler must be deregistered, not leaked in the chain")
}
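
The failure mode described in the test comment above (tracking registrations by handler ID collapses all zones onto one entry) can be reproduced in isolation. This is a rough sketch, not the server code: `registration`, `trackByID` and `trackAll` are hypothetical names, and the handler is reduced to its ID string.

```go
package main

import "fmt"

// registration is a hypothetical, trimmed-down handlerWrapper: the handler
// is reduced to its ID string.
type registration struct {
	domain    string
	handlerID string
}

// trackByID mimics a map keyed by handler ID: later entries overwrite
// earlier ones that share an ID.
func trackByID(regs []registration) map[string]registration {
	m := make(map[string]registration)
	for _, r := range regs {
		m[r.handlerID] = r
	}
	return m
}

// trackAll mimics a slice of registrations: one entry per
// (handler, domain) pair, so nothing is overwritten.
func trackAll(regs []registration) []registration {
	return append([]registration(nil), regs...)
}

func main() {
	regs := []registration{
		{domain: "userzone.test", handlerID: "local-resolver"},
		{domain: "peerzone.test", handlerID: "local-resolver"},
	}
	fmt.Println(len(trackByID(regs)), len(trackAll(regs))) // prints "1 2"
}
```

With the map, only the last zone registered under "local-resolver" is remembered, so a later teardown loop has no record of the first zone and cannot deregister it. The slice keeps both registrations.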
// TestDefaultServer_UpdateMux_PreservesLocalResolver verifies that updateMux
// does not tear down the shared local resolver during reconfiguration. The
// resolver is a process-lifetime singleton reused across config updates;
// Stop() cancels its lookup context (breaking external CNAME-target
// resolution) and clears its records. updateMux must deregister its chain
// entries without stopping it. Records surviving a teardown update is the
// observable proxy: Stop() would have cleared them.
func TestDefaultServer_UpdateMux_PreservesLocalResolver(t *testing.T) {
resolver := local.NewResolver()
require.NoError(t, resolver.RegisterRecord(nbdns.SimpleRecord{
Name: "peer.netbird.cloud.",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "10.0.0.1",
}))
server := &DefaultServer{
handlerChain: NewHandlerChain(),
service: &mockService{},
localResolver: resolver,
}
server.updateMux([]handlerWrapper{
{domain: "netbird.cloud", handler: resolver, priority: PriorityLocal},
})
// Remove the zone. The resolver must survive so its records and lookup
// context stay intact for the next registration.
server.updateMux(nil)
var response *dns.Msg
resolver.ServeDNS(&test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
response = m
return nil
},
}, &dns.Msg{Question: []dns.Question{{Name: "peer.netbird.cloud.", Qtype: dns.TypeA, Qclass: dns.ClassINET}}})
require.NotNil(t, response, "local resolver should answer after teardown")
assert.Equal(t, dns.RcodeSuccess, response.Rcode,
"local resolver records must survive teardown; updateMux must not Stop() the shared resolver")
assert.NotEmpty(t, response.Answer, "answer should contain the surviving record")
}
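
The guard that updateMux applies (`existing.handler != s.localResolver`) compares an interface value against a concrete pointer, so it matches on instance identity rather than on handler ID. A small sketch of that comparison, assuming hypothetical `handler`, `resolver` and `stopAllExcept` names that do not exist in the codebase:

```go
package main

import "fmt"

// handler is a hypothetical stand-in for the DNS handler interface.
type handler interface {
	Stop()
}

// resolver is a hypothetical handler that records whether it was stopped.
type resolver struct {
	name    string
	stopped bool
}

func (r *resolver) Stop() { r.stopped = true }

// stopAllExcept stops every handler except the one that is the same
// instance as keep. The comparison is by pointer identity, not by name.
func stopAllExcept(handlers []handler, keep *resolver) {
	for _, h := range handlers {
		if h != keep {
			h.Stop()
		}
	}
}

func main() {
	shared := &resolver{name: "local-resolver"}
	upstream := &resolver{name: "upstream"}
	// The shared instance appears twice, once per zone it serves.
	stopAllExcept([]handler{shared, upstream, shared}, shared)
	fmt.Println(shared.stopped, upstream.stopped) // prints "false true"
}
```

A second resolver with the same name would still be stopped, because only the exact instance passed as `keep` is exempt.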
func TestExtraDomains(t *testing.T) { func TestExtraDomains(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -2049,7 +2136,6 @@ func TestBuildUpstreamHandler_MergesGroupsPerDomain(t *testing.T) {
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
} }
groups := []*nbdns.NameServerGroup{ groups := []*nbdns.NameServerGroup{
@@ -2207,7 +2293,7 @@ func TestEvaluateNSGroupHealth(t *testing.T) {
} }
} }
// healthStubHandler is a minimal dnsMuxMap entry that exposes a fixed // healthStubHandler is a minimal dnsMuxHandlers entry that exposes a fixed
// UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates // UpstreamHealth snapshot, letting tests drive recomputeNSGroupStates
// without spinning up real handlers. // without spinning up real handlers.
type healthStubHandler struct { type healthStubHandler struct {
@@ -2283,12 +2369,11 @@ func newProjTestFixture(t *testing.T) *projTestFixture {
ctx: context.Background(), ctx: context.Background(),
wgInterface: &mocWGIface{}, wgInterface: &mocWGIface{},
statusRecorder: recorder, statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return fx.selected }, selectedRoutes: func() route.HAMap { return fx.selected },
activeRoutes: func() route.HAMap { return fx.active }, activeRoutes: func() route.HAMap { return fx.active },
warningDelayBase: defaultWarningDelayBase, warningDelayBase: defaultWarningDelayBase,
} }
fx.server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: fx.stub, priority: PriorityUpstream} fx.server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: fx.stub, priority: PriorityUpstream}}
fx.server.mux.Lock() fx.server.mux.Lock()
fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group}) fx.server.updateNSGroupStates([]*nbdns.NameServerGroup{fx.group})
@@ -2395,7 +2480,6 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
ctx: context.Background(), ctx: context.Background(),
wgInterface: &mocWGIface{}, wgInterface: &mocWGIface{},
statusRecorder: recorder, statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return nil }, selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil }, activeRoutes: func() route.HAMap { return nil },
warningDelayBase: 50 * time.Millisecond, warningDelayBase: 50 * time.Millisecond,
@@ -2407,7 +2491,7 @@ func TestProjection_OverlayAddrNoRouteDelaysWarning(t *testing.T) {
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{ stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{
overlayPeer: {LastFail: time.Now(), LastErr: "timeout"}, overlayPeer: {LastFail: time.Now(), LastErr: "timeout"},
}} }}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock() server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2444,7 +2528,6 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
service: NewServiceViaMemory(wgIface), service: NewServiceViaMemory(wgIface),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
extraDomains: map[domain.Domain]int{}, extraDomains: map[domain.Domain]int{},
dnsMuxMap: make(registeredHandlerMap),
statusRecorder: peer.NewRecorder("mgm"), statusRecorder: peer.NewRecorder("mgm"),
selectedRoutes: func() route.HAMap { return nil }, selectedRoutes: func() route.HAMap { return nil },
activeRoutes: func() route.HAMap { return nil }, activeRoutes: func() route.HAMap { return nil },
@@ -2459,7 +2542,7 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}}, NameServers: []nbdns.NameServer{{IP: srv.Addr(), NSType: nbdns.UDPNameServerType, Port: int(srv.Port())}},
} }
stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}} stub := &healthStubHandler{health: map[netip.AddrPort]UpstreamHealth{srv: {LastOk: time.Now()}}}
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock() server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2484,6 +2567,32 @@ func TestProjection_StopClearsHealthState(t *testing.T) {
// rule 3: startup failures while the peer is handshaking, then the peer // rule 3: startup failures while the peer is handshaking, then the peer
// comes up and a query succeeds before the grace window elapses. No // comes up and a query succeeds before the grace window elapses. No
// warning should ever have fired, and no recovery either. // warning should ever have fired, and no recovery either.
func TestWarningDelayBaseFromEnv(t *testing.T) {
tests := []struct {
name string
set bool
val string
want time.Duration
}{
{name: "unset uses default", set: false, want: defaultWarningDelayBase},
{name: "valid override", set: true, val: "90s", want: 90 * time.Second},
{name: "valid minutes", set: true, val: "2m", want: 2 * time.Minute},
{name: "invalid falls back", set: true, val: "notaduration", want: defaultWarningDelayBase},
{name: "zero falls back", set: true, val: "0s", want: defaultWarningDelayBase},
{name: "negative falls back", set: true, val: "-30s", want: defaultWarningDelayBase},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envWarningDelay, tc.val)
if !tc.set {
os.Unsetenv(envWarningDelay)
}
assert.Equal(t, tc.want, warningDelayBaseFromEnv(), "grace window base")
})
}
}
func TestProjection_OverlayRecoversDuringGrace(t *testing.T) { func TestProjection_OverlayRecoversDuringGrace(t *testing.T) {
fx := newProjTestFixture(t) fx := newProjTestFixture(t)
fx.server.warningDelayBase = 200 * time.Millisecond fx.server.warningDelayBase = 200 * time.Millisecond
@@ -2595,7 +2704,6 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
server := &DefaultServer{ server := &DefaultServer{
ctx: context.Background(), ctx: context.Background(),
statusRecorder: recorder, statusRecorder: recorder,
dnsMuxMap: make(registeredHandlerMap),
selectedRoutes: func() route.HAMap { return overlayMap }, selectedRoutes: func() route.HAMap { return overlayMap },
activeRoutes: func() route.HAMap { return nil }, activeRoutes: func() route.HAMap { return nil },
warningDelayBase: time.Hour, warningDelayBase: time.Hour,
@@ -2613,7 +2721,7 @@ func TestProjection_MixedGroupEmitsImmediately(t *testing.T) {
overlay: {LastFail: time.Now(), LastErr: "timeout"}, overlay: {LastFail: time.Now(), LastErr: "timeout"},
}, },
} }
server.dnsMuxMap["example.com"] = handlerWrapper{domain: "example.com", handler: stub, priority: PriorityUpstream} server.dnsMuxHandlers = []handlerWrapper{{domain: "example.com", handler: stub, priority: PriorityUpstream}}
server.mux.Lock() server.mux.Lock()
server.updateNSGroupStates([]*nbdns.NameServerGroup{group}) server.updateNSGroupStates([]*nbdns.NameServerGroup{group})
@@ -2640,7 +2748,6 @@ func TestDNSLoopPrevention(t *testing.T) {
localResolver: local.NewResolver(), localResolver: local.NewResolver(),
handlerChain: NewHandlerChain(), handlerChain: NewHandlerChain(),
hostManager: &noopHostConfigurator{}, hostManager: &noopHostConfigurator{},
dnsMuxMap: make(registeredHandlerMap),
} }
tests := []struct { tests := []struct {

@@ -443,29 +443,32 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, r *dns.M
return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"} return raceResult{}, &upstreamFailure{upstream: upstream, reason: "no response"}
} }
// A valid response means the upstream is reachable, whatever the Rcode.
u.markUpstreamOk(upstream)
proto := "" proto := ""
if upstreamProto != nil { if upstreamProto != nil {
proto = upstreamProto.protocol proto = upstreamProto.protocol
} }
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused { if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
// SERVFAIL and REFUSED are per-question outcomes (DNSSEC-bogus names,
// refused zones, transient recursion errors), not reachability
// problems: fail over for a better answer but keep the upstream healthy.
if code, ok := nonRetryableEDE(rm); ok { if code, ok := nonRetryableEDE(rm); ok {
if !hadEdns { if !hadEdns {
stripOPT(rm) resutil.StripOPT(rm)
} }
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil return raceResult{msg: rm, upstream: upstream, protocol: proto, ede: edeName(code)}, nil
} }
reason := dns.RcodeToString[rm.Rcode] reason := dns.RcodeToString[rm.Rcode]
u.markUpstreamFail(upstream, reason)
return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason} return raceResult{}, &upstreamFailure{upstream: upstream, reason: reason}
} }
if !hadEdns { if !hadEdns {
stripOPT(rm) resutil.StripOPT(rm)
} }
u.markUpstreamOk(upstream)
return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil return raceResult{msg: rm, upstream: upstream, protocol: proto}, nil
} }
@@ -520,22 +523,6 @@ func upstreamUDPSize() uint16 {
return dns.MinMsgSize return dns.MinMsgSize
} }
// stripOPT removes any OPT pseudo-RRs from the response's Extra section so
// the response complies with RFC 6891 when the client did not advertise EDNS0.
func stripOPT(rm *dns.Msg) {
if len(rm.Extra) == 0 {
return
}
out := rm.Extra[:0]
for _, rr := range rm.Extra {
if _, ok := rr.(*dns.OPT); ok {
continue
}
out = append(out, rr)
}
rm.Extra = out
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure { func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) { if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()} return &upstreamFailure{upstream: upstream, reason: err.Error()}

@@ -517,6 +517,78 @@ func TestUpstreamResolver_HealthTracking(t *testing.T) {
assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers") assert.NotContains(t, health, bad, "sibling upstream should not be queried when primary answers")
} }
// TestUpstreamResolver_HealthTracking_ResponseMeansReachable verifies that an
// upstream which answers with SERVFAIL or REFUSED is recorded as healthy:
// those are per-question outcomes from a reachable server and must not mark
// the upstream unhealthy. Only transport failures (timeouts) do.
func TestUpstreamResolver_HealthTracking_ResponseMeansReachable(t *testing.T) {
a := netip.MustParseAddrPort("192.0.2.10:53")
b := netip.MustParseAddrPort("192.0.2.11:53")
timeoutErr := &net.OpError{Op: "read", Err: fmt.Errorf("i/o timeout")}
tests := []struct {
name string
respA mockUpstreamResponse
respB mockUpstreamResponse
wantHealthy bool
}{
{
name: "both SERVFAIL are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeServerFailure, "")},
wantHealthy: true,
},
{
name: "both REFUSED are reachable",
respA: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
respB: mockUpstreamResponse{msg: buildMockResponse(dns.RcodeRefused, "")},
wantHealthy: true,
},
{
name: "timeout marks unhealthy",
respA: mockUpstreamResponse{err: timeoutErr},
respB: mockUpstreamResponse{err: timeoutErr},
wantHealthy: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
mockClient := &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
a.String(): tc.respA,
b.String(): tc.respB,
},
rtt: time.Millisecond,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: mockClient,
upstreamTimeout: UpstreamTimeout,
}
resolver.addRace([]netip.AddrPort{a, b})
responseWriter := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { return nil }}
resolver.ServeDNS(responseWriter, new(dns.Msg).SetQuestion("example.com.", dns.TypeA))
health := resolver.UpstreamHealth()
require.Contains(t, health, a, "primary upstream should have a health record")
if tc.wantHealthy {
assert.False(t, health[a].LastOk.IsZero(), "responding upstream should have LastOk set")
assert.True(t, health[a].LastFail.IsZero(), "responding upstream should not be marked failed")
assert.Empty(t, health[a].LastErr, "responding upstream should have no error")
} else {
assert.False(t, health[a].LastFail.IsZero(), "timed-out upstream should be marked failed")
assert.NotEmpty(t, health[a].LastErr, "timed-out upstream should record an error")
}
})
}
}
func TestFormatFailures(t *testing.T) { func TestFormatFailures(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
@@ -913,19 +985,6 @@ func TestEDEName(t *testing.T) {
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric") assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
} }
func TestStripOPT(t *testing.T) {
rm := &dns.Msg{
Extra: []dns.RR{
&dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}},
&dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)},
},
}
stripOPT(rm)
assert.Len(t, rm.Extra, 1, "OPT should be removed, A kept")
_, isOPT := rm.Extra[0].(*dns.OPT)
assert.False(t, isOPT, "remaining record must not be OPT")
}
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) { func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53") upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53") upstream2 := netip.MustParseAddrPort("192.0.2.2:53")

@@ -26,6 +26,15 @@ import (
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
// EDE info codes the forwarder emits on upstream failures so the querying
// client can see the reason without inspecting this peer's logs. They live in
// the RFC 8914 Private Use range (49152-65535); the Go resolver never exposes a
// real upstream EDE here, so these cannot collide with a genuine code.
const (
edeNetbirdUpstreamTimeout uint16 = 49152
edeNetbirdUpstreamFailure uint16 = 49153
)
type resolver interface { type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
} }
@@ -220,7 +229,7 @@ func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, q
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype) result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
if result.Err != nil { if result.Err != nil {
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime) f.handleDNSError(ctx, logger, w, question, resp, qname, result, query.IsEdns0() != nil, startTime)
return return
} }
@@ -333,6 +342,7 @@ func (f *DNSForwarder) handleDNSError(
resp *dns.Msg, resp *dns.Msg,
domain string, domain string,
result resutil.LookupResult, result resutil.LookupResult,
reqHasEdns bool,
startTime time.Time, startTime time.Time,
) { ) {
qType := question.Qtype qType := question.Qtype
@@ -374,6 +384,10 @@ func (f *DNSForwarder) handleDNSError(
logger.Warnf(errResolveFailed, domain, result.Err) logger.Warnf(errResolveFailed, domain, result.Err)
} }
if reqHasEdns {
attachEDE(resp, edeCodeFor(dnsErr), edeText(dnsErr))
}
f.writeResponse(logger, w, resp, domain, startTime) f.writeResponse(logger, w, resp, domain, startTime)
} }
@@ -414,3 +428,33 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// edeCodeFor maps an upstream lookup error to the NetBird EDE info code.
func edeCodeFor(dnsErr *net.DNSError) uint16 {
if dnsErr != nil && dnsErr.IsTimeout {
return edeNetbirdUpstreamTimeout
}
return edeNetbirdUpstreamFailure
}
// edeText builds the EDE extra-text describing the class of upstream failure.
// It deliberately omits the upstream server address, which may be an internal
// resolver and is exposed to any client permitted to use the route; the full
// detail stays in the forwarder's local log.
func edeText(dnsErr *net.DNSError) string {
if dnsErr != nil && dnsErr.IsTimeout {
return "netbird forwarder: upstream timeout"
}
return "netbird forwarder: upstream failure"
}
// attachEDE adds an Extended DNS Error (RFC 8914) option to the response,
// creating the OPT pseudo-record if the response does not already carry one.
func attachEDE(resp *dns.Msg, code uint16, text string) {
opt := resp.IsEdns0()
if opt == nil {
resp.SetEdns0(dns.DefaultMsgSize, false)
opt = resp.IsEdns0()
}
opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: code, ExtraText: text})
}

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -617,6 +618,85 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
} }
} }
func TestDNSForwarder_UpstreamFailureEDE(t *testing.T) {
tests := []struct {
name string
lookupErr error
reqEdns bool
wantEDE bool
wantCode uint16
wantTextHas string
}{
{
name: "timeout with edns0",
lookupErr: &net.DNSError{Err: "i/o timeout", Server: "10.0.0.53:53", IsTimeout: true},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamTimeout,
wantTextHas: "netbird forwarder: upstream timeout",
},
{
name: "server failure with edns0",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: true,
wantEDE: true,
wantCode: edeNetbirdUpstreamFailure,
wantTextHas: "netbird forwarder: upstream failure",
},
{
name: "no edns0 in request omits ede",
lookupErr: &net.DNSError{Err: "server misbehaving", Server: "10.0.0.53:53"},
reqEdns: false,
wantEDE: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
forwarder.resolver = mockResolver
d, err := domain.FromString("example.com")
require.NoError(t, err)
forwarder.UpdateDomains([]*ForwarderEntry{{Domain: d, ResID: "test-res"}})
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return([]netip.Addr(nil), tt.lookupErr).Once()
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
if tt.reqEdns {
query.SetEdns0(dns.DefaultMsgSize, false)
}
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
mockResolver.AssertExpectations(t)
require.NotNil(t, writtenResp, "expected a response")
assert.Equal(t, dns.RcodeServerFailure, writtenResp.Rcode, "upstream failure must be SERVFAIL")
ede, ok := resutil.ExtractEDE(writtenResp)
if !tt.wantEDE {
assert.False(t, ok, "response must not carry EDE")
return
}
require.True(t, ok, "response must carry EDE")
assert.Equal(t, tt.wantCode, ede.InfoCode, "EDE info code")
assert.Contains(t, ede.ExtraText, tt.wantTextHas, "EDE extra-text")
assert.NotContains(t, ede.ExtraText, "10.0.0.53", "must not leak upstream server address")
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set // Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{} mockResolver := &MockResolver{}

@@ -86,6 +86,8 @@ const (
var ErrResetConnection = fmt.Errorf("reset connection") var ErrResetConnection = fmt.Errorf("reset connection")
var ErrEngineAlreadyStarted = errors.New("engine already started")
type EngineConfig struct { type EngineConfig struct {
WgPort int WgPort int
WgIfaceName string WgIfaceName string
@@ -199,6 +201,8 @@ type Engine struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
started bool
wgInterface WGIface wgInterface WGIface
udpMux *udpmux.UniversalUDPMuxDefault udpMux *udpmux.UniversalUDPMuxDefault
@@ -279,9 +283,15 @@ func NewEngine(
services EngineServices, services EngineServices,
mobileDep MobileDependency, mobileDep MobileDependency,
) *Engine { ) *Engine {
// The engine is single-use: a fresh instance is built per connection
// cycle (see Client.run), so the run context is created once here rather
// than in Start.
ctx, cancel := context.WithCancel(clientCtx)
engine := &Engine{ engine := &Engine{
clientCtx: clientCtx, clientCtx: clientCtx,
clientCancel: clientCancel, clientCancel: clientCancel,
ctx: ctx,
cancel: cancel,
signal: services.SignalClient, signal: services.SignalClient,
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey), signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
mgmClient: services.MgmClient, mgmClient: services.MgmClient,
@@ -314,8 +324,34 @@ func (e *Engine) Stop() error {
log.Debugf("tried stopping engine that is nil") log.Debugf("tried stopping engine that is nil")
return nil return nil
} }
e.cancel()
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
e.stopLocked()
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// stopLocked tears down everything Start may have brought up, in the order
// teardown requires (DNS before the interface goes down, flow manager after).
// The caller must hold syncMsgMux. It is shared by Stop and by Start's failure
// path, so a partially-initialized engine is cleaned up the same way; every
// step is nil-guarded. It does not wait on shutdownWg — the caller does that
// after releasing the lock, since the goroutines also take syncMsgMux.
func (e *Engine) stopLocked() {
if e.connMgr != nil { if e.connMgr != nil {
e.connMgr.Close() e.connMgr.Close()
} }
@@ -366,10 +402,6 @@ func (e *Engine) Stop() error {
// so dbus and friends don't complain because of a missing interface // so dbus and friends don't complain because of a missing interface
e.stopDNSServer() e.stopDNSServer()
if e.cancel != nil {
e.cancel()
}
e.jobExecutorWG.Wait() // block until job goroutines finish e.jobExecutorWG.Wait() // block until job goroutines finish
e.close() e.close()
@@ -388,21 +420,6 @@ func (e *Engine) Stop() error {
if err := e.stateManager.PersistState(context.Background()); err != nil { if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err) log.Errorf("failed to persist state: %v", err)
} }
e.syncMsgMux.Unlock()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
} }
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. // calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
@@ -440,18 +457,38 @@ func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here. // Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service // However, they will be established once an event with a list of peers to connect to will be received from Management Service
func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) error { func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) (err error) {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
if err := iface.ValidateMTU(e.config.MTU); err != nil { // The engine is single-use. Reject a duplicate start and a start on an
// already-stopped engine (run context cancelled).
if e.started {
return ErrEngineAlreadyStarted
}
if ctxErr := e.ctx.Err(); ctxErr != nil {
return fmt.Errorf("engine already stopped: %w", ctxErr)
}
e.started = true
// Tear down any partially-initialized state on a failed start. Cancel the
// run context first so goroutines started before the failure (connMgr,
// srWatcher, monitors) unwind, then stopLocked mirrors Stop's teardown (we
// already hold syncMsgMux), cleaning up route/DNS/flow/state managers too,
// not just what close() covers.
defer func() {
if err != nil {
e.cancel()
e.stopLocked()
}
}()
if err = iface.ValidateMTU(e.config.MTU); err != nil {
return fmt.Errorf("invalid MTU configuration: %w", err) return fmt.Errorf("invalid MTU configuration: %w", err)
} }
if e.cancel != nil {
e.cancel()
}
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient) e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
wgIface, err := e.newWgIface() wgIface, err := e.newWgIface()
@@ -485,13 +522,11 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings() initialRoutes, dnsConfig, dnsFeatureFlag, err := e.readInitialSettings()
if err != nil { if err != nil {
e.close()
return fmt.Errorf("read initial settings: %w", err) return fmt.Errorf("read initial settings: %w", err)
} }
dnsServer, err := e.newDnsServer(dnsConfig) dnsServer, err := e.newDnsServer(dnsConfig)
if err != nil { if err != nil {
e.close()
return fmt.Errorf("create dns server: %w", err) return fmt.Errorf("create dns server: %w", err)
} }
e.dnsServer = dnsServer e.dnsServer = dnsServer
@@ -526,7 +561,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
if err = e.wgInterfaceCreate(); err != nil { if err = e.wgInterfaceCreate(); err != nil {
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error()) log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
e.close()
return fmt.Errorf("create wg interface: %w", err) return fmt.Errorf("create wg interface: %w", err)
} }
@@ -535,7 +569,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
} }
if err := e.createFirewall(); err != nil { if err := e.createFirewall(); err != nil {
e.close()
return err return err
} }
@@ -547,7 +580,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.udpMux, err = e.wgInterface.Up() e.udpMux, err = e.wgInterface.Up()
if err != nil { if err != nil {
log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error()) log.Errorf("failed to pull up wgInterface [%s]: %s", e.wgInterface.Name(), err.Error())
e.close()
return fmt.Errorf("up wg interface: %w", err) return fmt.Errorf("up wg interface: %w", err)
} }
@@ -572,9 +604,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.acl = acl.NewDefaultManager(e.firewall) e.acl = acl.NewDefaultManager(e.firewall)
} }
err = e.dnsServer.Initialize() if err := e.dnsServer.Initialize(); err != nil {
if err != nil {
e.close()
return fmt.Errorf("initialize dns server: %w", err) return fmt.Errorf("initialize dns server: %w", err)
} }
@@ -586,7 +616,9 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed()) e.srWatcher.Start(peer.IsForceRelayed())
e.receiveSignalEvents() if err = e.receiveSignalEvents(); err != nil {
return err
}
e.receiveManagementEvents() e.receiveManagementEvents()
e.receiveJobEvents() e.receiveJobEvents()
@@ -638,7 +670,6 @@ func (e *Engine) createFirewall() error {
func (e *Engine) initFirewall() error { func (e *Engine) initFirewall() error {
if err := e.routeManager.SetFirewall(e.firewall); err != nil { if err := e.routeManager.SetFirewall(e.firewall); err != nil {
e.close()
return fmt.Errorf("set firewall: %w", err) return fmt.Errorf("set firewall: %w", err)
} }
@@ -1698,7 +1729,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
} }
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() { func (e *Engine) receiveSignalEvents() error {
e.shutdownWg.Add(1) e.shutdownWg.Add(1)
go func() { go func() {
defer e.shutdownWg.Done() defer e.shutdownWg.Done()
@@ -1769,7 +1800,12 @@ func (e *Engine) receiveSignalEvents() {
} }
}() }()
e.signal.WaitStreamConnected() // TODO: consider removing this blocking wait; there is no clear benefit to blocking Start on it
e.signal.WaitStreamConnected(e.ctx)
if err := e.ctx.Err(); err != nil {
return fmt.Errorf("wait for signal stream: %w", err)
}
return nil
} }
func (e *Engine) parseNATExternalIPMappings() []string { func (e *Engine) parseNATExternalIPMappings() []string {

@@ -247,7 +247,7 @@ func TestEngine_SSH(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -426,7 +426,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU) relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
@@ -638,7 +638,7 @@ func TestEngine_Sync(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
// feed updates to Engine via mocked Management client // feed updates to Engine via mocked Management client
@@ -817,7 +817,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)
@@ -1024,7 +1024,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
return return
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(CtxInitState(context.Background()))
defer cancel() defer cancel()
wgIfaceName := fmt.Sprintf("utun%d", 104+n) wgIfaceName := fmt.Sprintf("utun%d", 104+n)

@@ -251,6 +251,14 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
// Advertise EDNS0 to the forwarder so it may return an Extended DNS Error
// describing why a lookup failed. The OPT is stripped from the reply when
// the original client did not request EDNS0.
hadEdns := r.IsEdns0() != nil
if !hadEdns {
r.SetEdns0(dns.DefaultMsgSize, false)
}
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10)) upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel() defer cancel()
@@ -260,6 +268,13 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
if ede, ok := resutil.ExtractEDE(reply); ok {
resutil.SetMeta(w, "ede", fmt.Sprintf("%d %s", ede.InfoCode, ede.ExtraText))
}
if !hadEdns {
resutil.StripOPT(reply)
}
resutil.SetMeta(w, "peer", peerKey) resutil.SetMeta(w, "peer", peerKey)
reply.Id = r.Id reply.Id = r.Id

@@ -36,6 +36,7 @@ type URLOpener interface {
// Auth can register or login new client // Auth can register or login new client
type Auth struct { type Auth struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc
config *profilemanager.Config config *profilemanager.Config
cfgPath string cfgPath string
} }
@@ -51,8 +52,19 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
return nil, err return nil, err
} }
// Use a cancellable context so Stop() can abort an in-progress interactive
// login. The PKCE flow's WaitToken blocks (and keeps its loopback HTTP server
// bound to a port) until the OAuth callback arrives or the flow expires;
// cancelling the context unblocks WaitToken, which then shuts that server down
// and frees the port for the next login attempt. iOS runs login in the main-app
// process (decoupled from the network extension), so without this the server
// lingers after the user dismisses the browser and the next connect stalls
// trying to bind the same port.
ctx, cancel := context.WithCancel(context.Background())
return &Auth{ return &Auth{
ctx: context.Background(), ctx: ctx,
cancel: cancel,
config: cfg, config: cfg,
cfgPath: cfgPath, cfgPath: cfgPath,
}, nil }, nil
@@ -60,12 +72,24 @@ func NewAuth(cfgPath string, mgmURL string) (*Auth, error) {
// NewAuthWithConfig instantiate Auth based on existing config // NewAuthWithConfig instantiate Auth based on existing config
func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth { func NewAuthWithConfig(ctx context.Context, config *profilemanager.Config) *Auth {
ctx, cancel := context.WithCancel(ctx)
return &Auth{ return &Auth{
ctx: ctx, ctx: ctx,
cancel: cancel,
config: config, config: config,
} }
} }
// Stop aborts an in-progress interactive login started via Login/LoginWithDeviceName.
// It cancels the auth context, which unblocks the PKCE WaitToken and shuts down its
// loopback HTTP server, freeing the redirect port. Safe to call multiple times and
// safe to call when no login is running.
func (a *Auth) Stop() {
if a.cancel != nil {
a.cancel()
}
}
// SaveConfigIfSSOSupported tests connectivity with the management server by retrieving the server device flow info. // SaveConfigIfSSOSupported tests connectivity with the management server by retrieving the server device flow info.
// If it returns flow info, it saves the configuration and returns true. If it gets a codes.NotFound, it means that SSO // If it returns flow info, it saves the configuration and returns true. If it gets a codes.NotFound, it means that SSO
// is not supported and it returns false without saving the configuration. For other errors it returns false. // is not supported and it returns false without saving the configuration. For other errors it returns false.

@@ -993,6 +993,10 @@ func (s *Server) cleanupConnection() error {
return nil return nil
} }
// TODO: consider calling s.connectClient.Stop() instead of engine.Stop().
// actCancel() lets the run loop stop the engine too, so both stop it
// concurrently; ConnectClient.Stop cancels and waits for the run loop,
// making the run loop the sole owner of engine shutdown.
if engine != nil { if engine != nil {
if err := engine.Stop(); err != nil { if err := engine.Stop(); err != nil {
return err return err

combined/cmd/admin.go (new file, 91 lines added)

@@ -0,0 +1,91 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newAdminCommands creates the admin command tree with combined-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(withAdminResources)
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
return cmd
}
// withAdminResources loads the combined YAML config, initializes stores, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, cfg *CombinedConfig) error {
mgmtConfig, err := cfg.ToManagementConfig()
if err != nil {
return fmt.Errorf("create management config: %w", err)
}
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(mgmtConfig.EmbeddedIdP)
if err != nil {
return err
}
defer func() {
if err := idpStorage.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}()
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
})
}
// withAdminTokenStore opens only the management store for admin token commands.
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *CombinedConfig) error {
return fn(ctx, managementStore)
})
}
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, cfg *CombinedConfig) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
managementStore, err := store.NewStore(ctx, types.Engine(cfg.Management.Store.Engine), cfg.Management.DataDir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := managementStore.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, managementStore, cfg)
}


@@ -64,7 +64,7 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "path to YAML configuration file (required)")
_ = rootCmd.MarkPersistentFlagRequired("config")
- rootCmd.AddCommand(newTokenCommands())
+ rootCmd.AddCommand(newAdminCommands())
}
func RootCmd() *cobra.Command {


@@ -1,63 +0,0 @@
package cmd
import (
"context"
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util"
)
// newTokenCommands creates the token command tree with combined-specific store opener.
func newTokenCommands() *cobra.Command {
return tokencmd.NewCommands(withTokenStore)
}
// withTokenStore loads the combined YAML config, initializes the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
cfg, err := LoadConfig(configPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
if dsn := cfg.Server.Store.DSN; dsn != "" {
switch strings.ToLower(cfg.Server.Store.Engine) {
case "postgres":
os.Setenv("NB_STORE_ENGINE_POSTGRES_DSN", dsn)
case "mysql":
os.Setenv("NB_STORE_ENGINE_MYSQL_DSN", dsn)
}
}
if file := cfg.Server.Store.File; file != "" {
os.Setenv("NB_STORE_ENGINE_SQLITE_FILE", file)
}
datadir := cfg.Management.DataDir
engine := types.Engine(cfg.Management.Store.Engine)
s, err := store.NewStore(ctx, engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}


@@ -0,0 +1,616 @@
#!/bin/bash
set -e
set -o pipefail
# NetBird Enterprise — Getting Started
# Single-node bootstrap for a self-hosted NetBird Enterprise stack with the
# embedded identity provider. Owner is created via first-login flow.
SED_STRIP_PADDING='s/=//g'
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
return
fi
if docker compose --help &> /dev/null; then
echo "docker compose"
return
fi
echo "docker-compose is not installed or not in PATH. See https://docs.docker.com/engine/install/" > /dev/stderr
exit 1
}
check_openssl() {
if ! command -v openssl &> /dev/null; then
echo "openssl is not installed or not in PATH." > /dev/stderr
exit 1
fi
}
rand_secret() {
openssl rand -base64 32 | sed "$SED_STRIP_PADDING"
}
rand_b64_key() {
openssl rand -base64 32
}
check_nb_domain() {
local domain="$1"
if [[ -z "$domain" ]]; then
echo "The domain cannot be empty." > /dev/stderr
return 1
fi
if [[ "$domain" == "netbird.example.com" ]]; then
echo "The domain cannot be netbird.example.com" > /dev/stderr
return 1
fi
if [[ "$domain" =~ ^[0-9.]+$ ]]; then
echo "An IP address is not allowed. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
return 1
fi
if [[ ! "$domain" =~ ^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$ ]]; then
echo "The value '$domain' is not a valid FQDN. A real DNS-resolvable domain is required for TLS and the embedded IdP issuer." > /dev/stderr
return 1
fi
return 0
}
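Review note: the two regular-expression checks in check_nb_domain can be exercised without the prompt loop. The sketch below copies them into a hypothetical helper, is_valid_fqdn, which is not part of the script. It leaves out the netbird.example.com placeholder check and the error messages, and, like the original, it does not enforce DNS label length limits.

```shell
#!/bin/bash
# Hypothetical helper for illustration; mirrors the regex checks in check_nb_domain.
is_valid_fqdn() {
  local domain="$1"
  # Reject empty input.
  [[ -n "$domain" ]] || return 1
  # Reject values made only of digits and dots (IPv4-looking input).
  if [[ "$domain" =~ ^[0-9.]+$ ]]; then
    return 1
  fi
  # Require at least two dot-separated labels; each label starts and ends
  # with an alphanumeric character and may contain hyphens in between.
  local fqdn_re='^[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?(\.[A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)+$'
  [[ "$domain" =~ $fqdn_re ]]
}

for candidate in netbird.my-domain.com 192.168.1.10 localhost -bad.example.com; do
  if is_valid_fqdn "$candidate"; then
    echo "$candidate: accepted"
  else
    echo "$candidate: rejected"
  fi
done
```

Only the first candidate should be accepted: the second looks like an IP address, the third has a single label, and the fourth starts with a hyphen.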
check_domain_resolves() {
local domain="$1"
if command -v getent &> /dev/null && getent hosts "$domain" &> /dev/null; then return 0; fi
if command -v host &> /dev/null && host "$domain" &> /dev/null; then return 0; fi
if command -v dig &> /dev/null && [[ -n "$(dig +short "$domain" 2>/dev/null)" ]]; then return 0; fi
if command -v nslookup &> /dev/null && nslookup "$domain" &> /dev/null; then return 0; fi
return 1
}
read_nb_domain() {
local value=""
echo -n "Enter the FQDN for NetBird (must resolve via DNS, e.g. netbird.my-domain.com): " > /dev/stderr
read -r value < /dev/tty
if ! check_nb_domain "$value"; then
read_nb_domain
return
fi
if ! check_domain_resolves "$value"; then
echo "" > /dev/stderr
echo "Warning: '$value' does not resolve via DNS from this host." > /dev/stderr
echo "Caddy will not be able to issue TLS certificates until it does." > /dev/stderr
local confirm=""
echo -n "Continue anyway? [y/N]: " > /dev/stderr
read -r confirm < /dev/tty
if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
read_nb_domain
return
fi
fi
echo "$value"
}
read_required() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -r value < /dev/tty
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
read_secret() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -rs value < /dev/tty
echo "" > /dev/stderr
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
# read_yes_no "<prompt>" [<default y|n>]
read_yes_no() {
local prompt="$1"
local default="${2:-n}"
local hint
if [[ "$default" == "y" ]]; then
hint="[Y/n]"
else
hint="[y/N]"
fi
echo -n "${prompt} ${hint}: " > /dev/stderr
local ans=""
read -r ans < /dev/tty
if [[ -z "$ans" ]]; then
ans="$default"
fi
case "$ans" in
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
*) echo "no" ;;
esac
}
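Review note: read_yes_no mixes terminal I/O with answer normalization, which makes the default handling hard to check by hand. A minimal sketch of the normalization alone, using a hypothetical normalize_yes_no helper (not in the script) that takes the answer as an argument instead of reading /dev/tty:

```shell
#!/bin/bash
# Hypothetical helper: same default handling and case patterns as read_yes_no,
# but the answer is passed in rather than read from the terminal.
normalize_yes_no() {
  local ans="$1"
  local default="${2:-n}"
  # An empty answer falls back to the default.
  if [[ -z "$ans" ]]; then
    ans="$default"
  fi
  # Only y/yes (any case) count as "yes"; everything else is "no".
  case "$ans" in
    [Yy] | [Yy][Ee][Ss]) echo "yes" ;;
    *) echo "no" ;;
  esac
}

normalize_yes_no "Y"
normalize_yes_no "" "y"
normalize_yes_no "yep"
normalize_yes_no ""
```

Any answer other than y or yes maps to "no", including near-misses such as "yep", so callers only ever need to compare against the literal string "yes".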
wait_postgres() {
set +e
echo -n "Waiting for postgres to become ready"
local counter=1
while true; do
if $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U "$POSTGRES_USER" -d "$POSTGRES_DB" &> /dev/null; then
break
fi
if [[ $counter -eq 60 ]]; then
echo ""
echo "Postgres is taking too long. Recent logs:"
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
exit 1
fi
echo -n " ."
sleep 2
counter=$((counter + 1))
done
echo " done"
set -e
}
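Review note: wait_postgres is a bounded polling loop (60 attempts, 2 seconds apart). The same shape can be sketched as a generic helper; wait_until and its arguments are hypothetical names used for illustration, not something the script defines.

```shell
#!/bin/bash
# Hypothetical helper: run a command until it succeeds or the attempt budget
# is exhausted. Usage: wait_until <max_attempts> <delay_seconds> <command...>
wait_until() {
  local max_attempts="$1"
  local delay="$2"
  shift 2
  local attempt=1
  while true; do
    # Success ends the loop immediately.
    if "$@" &> /dev/null; then
      return 0
    fi
    # Give up once the last allowed attempt has failed.
    if [[ $attempt -ge $max_attempts ]]; then
      return 1
    fi
    sleep "$delay"
    attempt=$((attempt + 1))
  done
}

if wait_until 3 0 true; then
  echo "ready"
fi
if ! wait_until 2 0 false; then
  echo "timed out"
fi
```

Returning a status instead of calling exit inside the loop leaves the decision about printing logs and aborting to the caller.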
init_environment() {
check_openssl
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
if [[ -f .env ]] || [[ -f docker-compose.yml ]] || [[ -f config.yaml ]] || [[ -f Caddyfile ]]; then
echo "Generated files already exist in $(pwd)."
echo "If you want to reinitialize the environment, please remove them first:"
echo " $DOCKER_COMPOSE_COMMAND down --volumes # removes all containers and volumes"
echo " rm -f .env docker-compose.yml Caddyfile config.yaml"
echo "Be aware this will remove all data from the database."
exit 1
fi
echo "NetBird Enterprise bootstrap"
echo ""
echo "Traffic flow:"
echo " Enables traffic events logging on the management server."
echo " When enabled, the NetBird stack also runs NATS along with two"
echo " additional containers: netbird-receiver (the traffic log receiver"
echo " service) and netbird-enricher (the traffic log enricher service)."
echo " It still has to be turned on from the dashboard settings afterwards."
echo " See https://docs.netbird.io/manage/activity/traffic-events-logging"
NETBIRD_TRAFFIC_FLOW=$(read_yes_no "Enable traffic flow" "n")
echo ""
NETBIRD_DOMAIN=$(read_nb_domain)
echo ""
NETBIRD_LICENSE_KEY=$(read_secret "Enter license key (input hidden)")
GHCR_USERNAME="netbirdExtAccess1"
GHCR_TOKEN=$(read_secret "Enter GHCR token (input hidden)")
POSTGRES_USER="netbird"
POSTGRES_DB="netbird"
POSTGRES_PASSWORD=$(rand_secret)
NETBIRD_ENCRYPTION_KEY=$(rand_b64_key)
NETBIRD_RELAY_AUTH_SECRET=$(rand_secret)
POSTGRES_DSN="host=postgres user=${POSTGRES_USER} password=${POSTGRES_PASSWORD} dbname=${POSTGRES_DB} port=5432 sslmode=disable TimeZone=UTC"
NETBIRD_RELAY_ENDPOINT="rels://${NETBIRD_DOMAIN}:443"
echo ""
echo "Selected:"
echo " Traffic flow: ${NETBIRD_TRAFFIC_FLOW}"
echo " Domain: ${NETBIRD_DOMAIN}"
echo ""
echo "Rendering files into $(pwd) ..."
install -m 600 /dev/null .env
render_env >> .env
render_docker_compose > docker-compose.yml
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' docker-compose.yml && rm -f docker-compose.yml.bak
fi
render_caddyfile > Caddyfile
install -m 600 /dev/null config.yaml
render_config_yaml >> config.yaml
echo "Logging in to ghcr.io ..."
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
unset GHCR_TOKEN
echo ""
echo "Pulling images ..."
$DOCKER_COMPOSE_COMMAND pull
echo ""
echo "Starting postgres ..."
$DOCKER_COMPOSE_COMMAND up -d postgres
sleep 2
wait_postgres
echo ""
echo "Starting remaining services ..."
$DOCKER_COMPOSE_COMMAND up -d
echo ""
echo "Done."
echo ""
echo "Dashboard: https://${NETBIRD_DOMAIN}"
echo ""
echo "Open the dashboard in a browser to complete the first-login owner setup."
echo "All configuration and secrets are stored (mode 600) in $(pwd)/.env and $(pwd)/config.yaml"
echo ""
echo "Tail logs:"
echo " cd $(pwd) && $DOCKER_COMPOSE_COMMAND logs -f netbird-server caddy"
}
# ------------------------------------------------------------------
# Renderers
# ------------------------------------------------------------------
render_env() {
cat <<EOF
# Generated by getting-started-enterprise.sh
# Holds all configuration and secrets for the stack. Mode 600.
# Features (set by the script; don't edit without re-running)
NETBIRD_TRAFFIC_FLOW_ENABLED=${NETBIRD_TRAFFIC_FLOW}
# Domain
NETBIRD_DOMAIN=${NETBIRD_DOMAIN}
# Image tags. Default to "latest"
NETBIRD_DASHBOARD_TAG=${NETBIRD_DASHBOARD_TAG:-latest}
NETBIRD_SERVER_TAG=${NETBIRD_SERVER_TAG:-latest}
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<EOF
NETBIRD_ENRICHER_TAG=${NETBIRD_ENRICHER_TAG:-latest}
NETBIRD_RECEIVER_TAG=${NETBIRD_RECEIVER_TAG:-latest}
EOF
fi
cat <<EOF
# License keys
EOF
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
cat <<EOF
NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
fi
cat <<EOF
NETBIRD_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
EOF
cat <<EOF
# Postgres
POSTGRES_USER=${POSTGRES_USER}
POSTGRES_DB=${POSTGRES_DB}
POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
NETBIRD_STORE_ENGINE_POSTGRES_DSN=${POSTGRES_DSN}
# Relay
NETBIRD_RELAY_ENDPOINT=${NETBIRD_RELAY_ENDPOINT}
NETBIRD_RELAY_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
# Datastore encryption
NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}
# Dashboard OIDC scopes
NETBIRD_AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES:-openid profile email groups}
EOF
}
render_docker_compose() {
render_compose_header
render_compose_common
render_compose_server
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
render_compose_flow
fi
render_compose_postgres
render_compose_footer
}
render_compose_header() {
cat <<'EOF'
x-default: &default
restart: unless-stopped
logging:
driver: json-file
options:
max-size: '500m'
max-file: '2'
services:
EOF
}
render_compose_common() {
cat <<'EOF'
caddy:
<<: *default
image: caddy:2
container_name: netbird-caddy
networks: [netbird]
environment:
- CADDY_SECURE_DOMAIN=${NETBIRD_DOMAIN}
ports:
- '443:443'
- '443:443/udp'
- '80:80'
volumes:
- netbird_caddy_data:/data
- ./Caddyfile:/etc/caddy/Caddyfile
dashboard:
<<: *default
image: ghcr.io/netbirdio/dashboard-cloud:${NETBIRD_DASHBOARD_TAG}
container_name: netbird-dashboard
networks: [netbird]
environment:
- NETBIRD_MGMT_API_ENDPOINT=https://${NETBIRD_DOMAIN}
- NETBIRD_MGMT_GRPC_API_ENDPOINT=https://${NETBIRD_DOMAIN}
- AUTH_AUDIENCE=netbird-dashboard
- AUTH_CLIENT_ID=netbird-dashboard
- AUTH_CLIENT_SECRET=
- AUTH_AUTHORITY=https://${NETBIRD_DOMAIN}/oauth2
- USE_AUTH0=false
- AUTH_SUPPORTED_SCOPES=${NETBIRD_AUTH_SUPPORTED_SCOPES}
- AUTH_REDIRECT_URI=/nb-auth
- AUTH_SILENT_REDIRECT_URI=/nb-silent-auth
- NETBIRD_TOKEN_SOURCE=accessToken
- NGINX_SSL_PORT=443
- LETSENCRYPT_DOMAIN=
- LETSENCRYPT_EMAIL=
EOF
}
render_compose_server() {
cat <<'EOF'
netbird-server:
<<: *default
image: ghcr.io/netbirdio/netbird-server-cloud:${NETBIRD_SERVER_TAG}
container_name: netbird-server
networks: [netbird]
depends_on:
dashboard:
condition: service_started
postgres:
condition: service_healthy
ports:
- '3478:3478/udp'
volumes:
- netbird_data:/var/lib/netbird
- ./config.yaml:/etc/netbird/config.yaml
command: ["--config", "/etc/netbird/config.yaml"]
environment:
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
}
render_compose_flow() {
cat <<'EOF'
nats:
<<: *default
image: nats:2
container_name: netbird-nats
networks: [netbird]
volumes:
- netbird_nats_data:/data
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
enricher:
<<: *default
image: ghcr.io/netbirdio/flow-enricher-cloud:${NETBIRD_ENRICHER_TAG}
container_name: netbird-enricher
networks: [netbird]
depends_on:
postgres:
condition: service_healthy
nats:
condition: service_started
volumes:
- netbird_enricher:/var/lib/netbird
environment:
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
- NB_DATADIR=/var/lib/netbird
- NB_MANAGEMENT_STORE_ENGINE=postgres
- NB_MANAGEMENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
- NETBIRD_STORE_ENGINE_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
- NB_TRAFFIC_EVENT_POSTGRES_DSN=${NETBIRD_STORE_ENGINE_POSTGRES_DSN}
- NB_TRAFFIC_EVENT_STORE_ENGINE=postgres
- NB_MANAGEMENT_STORE_KEY=${NETBIRD_ENCRYPTION_KEY}
- NB_FLOW_ADAPTER_TYPE=nats
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
- NB_FLOW_NATS_STREAM=traffic-events
- NB_METRICS_PORT=9091
- NB_PERSISTENCE_RETENTION_PERIOD=168h
receiver:
<<: *default
image: ghcr.io/netbirdio/flow-receiver-cloud:${NETBIRD_RECEIVER_TAG}
container_name: netbird-receiver
networks: [netbird]
depends_on:
nats:
condition: service_started
environment:
- NB_LICENSE_KEY=${NETBIRD_LICENSE_KEY}
- NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}
- NB_FLOW_LISTEN_PORT=80
- NB_FLOW_ADAPTER_TYPE=nats
- NB_FLOW_NATS_ENDPOINTS=nats://nats:4222
- NB_FLOW_NATS_STREAM=traffic-events
- NB_FLOW_AUTH_SECRET=${NETBIRD_RELAY_AUTH_SECRET}
EOF
}
render_compose_postgres() {
cat <<'EOF'
postgres:
<<: *default
image: postgres:17
container_name: netbird-postgres
networks: [netbird]
environment:
- POSTGRES_USER=${POSTGRES_USER}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
- POSTGRES_DB=${POSTGRES_DB}
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
interval: 10s
timeout: 5s
retries: 10
volumes:
- netbird_postgres:/var/lib/postgresql/data
EOF
}
render_compose_footer() {
cat <<'EOF'
volumes:
netbird_data:
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<'EOF'
netbird_nats_data:
netbird_enricher:
EOF
fi
cat <<'EOF'
netbird_postgres:
netbird_caddy_data:
networks:
netbird:
EOF
}
render_caddyfile() {
cat <<'EOF'
{
servers :80,:443 {
protocols h1 h2c h2 h3
}
}
(security_headers) {
header * {
Strict-Transport-Security "max-age=3600; includeSubDomains; preload"
X-Content-Type-Options "nosniff"
X-Frame-Options "SAMEORIGIN"
X-XSS-Protection "1; mode=block"
-Server
Referrer-Policy strict-origin-when-cross-origin
}
}
:80 {
redir https://{$CADDY_SECURE_DOMAIN}{uri} permanent
}
{$CADDY_SECURE_DOMAIN}:443 {
import security_headers
# Signal (gRPC over h2c)
reverse_proxy /signalexchange.SignalExchange/* h2c://netbird-server:80
# Management (gRPC over h2c + HTTP)
reverse_proxy /management.ManagementService/* h2c://netbird-server:80
reverse_proxy /api/* netbird-server:80
reverse_proxy /ws-proxy/* netbird-server:80
# Embedded IdP (OAuth2 endpoints served by netbird server)
reverse_proxy /oauth2/* netbird-server:80
# Relay (WebSocket multiplexed on the same port)
reverse_proxy /relay* netbird-server:80
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<'EOF'
# Flow receiver (gRPC over h2c)
reverse_proxy /flow.FlowService/* h2c://receiver:80
EOF
fi
cat <<'EOF'
# Dashboard
reverse_proxy /* dashboard:80
}
EOF
}
render_config_yaml() {
cat <<EOF
# NetBird Enterprise server configuration.
# Generated by getting-started-enterprise.sh. Mode 600.
server:
listenAddress: ":80"
exposedAddress: "https://${NETBIRD_DOMAIN}:443"
metricsPort: 9090
healthcheckAddress: ":9000"
logLevel: "info"
logFile: "console"
# TLS is terminated by Caddy in front; leave this block empty.
tls:
certFile: ""
keyFile: ""
letsencrypt:
enabled: false
authSecret: "${NETBIRD_RELAY_AUTH_SECRET}"
dataDir: "/var/lib/netbird/"
disableAnonymousMetrics: false
disableGeoliteUpdate: false
auth:
issuer: "https://${NETBIRD_DOMAIN}/oauth2"
localAuthDisabled: false
signKeyRefreshEnabled: false
dashboardRedirectURIs:
- "https://${NETBIRD_DOMAIN}/nb-auth"
- "https://${NETBIRD_DOMAIN}/nb-silent-auth"
cliRedirectURIs:
- "http://localhost:53000/"
store:
engine: "postgres"
dsn: "${POSTGRES_DSN}"
encryptionKey: "${NETBIRD_ENCRYPTION_KEY}"
activityStore:
engine: "postgres"
dsn: "${POSTGRES_DSN}"
EOF
if [[ "$NETBIRD_TRAFFIC_FLOW" == "yes" ]]; then
cat <<EOF
trafficFlow:
enabled: true
address: "https://${NETBIRD_DOMAIN}:443"
interval: "60s"
EOF
fi
}
init_environment


@@ -0,0 +1,638 @@
#!/bin/bash
set -e
set -o pipefail
# NetBird — community combined → Enterprise combined migration
#
# Non-destructive migration: produces docker-compose.override.yml (auto-loaded
# by docker compose) and config.yaml.enterprise alongside the operator's
# existing files. Original docker-compose.yml and config.yaml are never
# modified.
#
# Steps (all optional, asked interactively):
# 1. Image swap — replace community images with enterprise cloud images.
# 2. Postgres migration — add Postgres, migrate SQLite data via migrate-store.
# 3. Traffic flow — add NATS + flow-enricher + flow-receiver.
#
# To revert:
# docker compose down
# rm -f docker-compose.override.yml config.yaml.enterprise
# # If Postgres migration was done, also restore the SQLite backup printed
# # at the end of this script's run.
# docker compose up -d
OVERRIDE_FILE="docker-compose.override.yml"
ENTERPRISE_CONFIG_FILE="config.yaml.enterprise"
check_docker_compose() {
if command -v docker-compose &> /dev/null; then
echo "docker-compose"
return
fi
if docker compose --help &> /dev/null; then
echo "docker compose"
return
fi
echo "docker-compose is not installed or not in PATH." > /dev/stderr
exit 1
}
check_yq() {
if ! command -v yq &> /dev/null; then
cat > /dev/stderr <<'EOF'
yq is required to parse and update YAML safely.
macOS: brew install yq
Linux: https://github.com/mikefarah/yq/releases (download binary into PATH)
Debian: apt-get install yq (Note: must be the mikefarah Go yq, not the Python wrapper.)
EOF
exit 1
fi
if ! yq --version 2>&1 | grep -q "mikefarah"; then
echo "yq is present but appears to be the wrong implementation. The mikefarah Go-based yq is required (https://github.com/mikefarah/yq)." > /dev/stderr
exit 1
fi
}
check_openssl() {
if ! command -v openssl &> /dev/null; then
echo "openssl is not installed or not in PATH." > /dev/stderr
exit 1
fi
}
rand_password() {
openssl rand -hex 32
}
read_required() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -r value < /dev/tty
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
read_secret() {
local prompt="$1"
local value=""
while [[ -z "$value" ]]; do
echo -n "$prompt: " > /dev/stderr
read -rs value < /dev/tty
echo "" > /dev/stderr
if [[ -z "$value" ]]; then
echo "Value cannot be empty." > /dev/stderr
fi
done
echo "$value"
}
read_yes_no() {
local prompt="$1"
local default="${2:-n}"
local hint
if [[ "$default" == "y" ]]; then
hint="[Y/n]"
else
hint="[y/N]"
fi
echo -n "${prompt} ${hint}: " > /dev/stderr
local ans=""
read -r ans < /dev/tty
if [[ -z "$ans" ]]; then
ans="$default"
fi
case "$ans" in
[Yy] | [Yy][Ee][Ss]) echo "yes" ;;
*) echo "no" ;;
esac
}
# ---------------------------------------------------------------------------
# Detection — read the operator's existing compose to find service names and
# paths we need to override. Bail loudly if shape isn't recognised.
# ---------------------------------------------------------------------------
detect_combined_service() {
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/netbird-server"))) | .[0].key // ""' "$COMPOSE_FILE"
}
detect_dashboard_service() {
yq eval '.services | to_entries | map(select(.value.image | test("^netbirdio/dashboard"))) | .[0].key // ""' "$COMPOSE_FILE"
}
detect_config_yaml_host_path() {
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/etc/netbird/config.yaml\")) | sub(\":/etc/netbird/config.yaml.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
}
detect_data_volume() {
yq eval ".services[\"$COMBINED_SERVICE\"].volumes[] | select(. | test(\":/var/lib/netbird\")) | sub(\":/var/lib/netbird.*\"; \"\") // \"\"" "$COMPOSE_FILE" | head -1
}
detect_exposed_address() {
yq eval '.server.exposedAddress // ""' "$CONFIG_YAML_HOST"
}
detect_compose_network() {
local tag
tag=$(yq eval ".services[\"$COMBINED_SERVICE\"].networks | tag" "$COMPOSE_FILE" 2>/dev/null)
case "$tag" in
"!!seq")
yq eval ".services[\"$COMBINED_SERVICE\"].networks[0]" "$COMPOSE_FILE"
;;
"!!map")
yq eval ".services[\"$COMBINED_SERVICE\"].networks | keys | .[0]" "$COMPOSE_FILE"
;;
*)
echo "default"
;;
esac
}
# ---------------------------------------------------------------------------
# Renderers
# ---------------------------------------------------------------------------
# Build docker-compose.override.yml from the steps the operator selected.
# Service names match what we detected on the operator's side.
render_override() {
cat <<EOF
# Generated by migrate-to-enterprise.sh. Mode 644.
# Merged with docker-compose.yml automatically by Docker Compose.
# Remove this file (and config.yaml.enterprise if present) to revert.
services:
${DASHBOARD_SERVICE}:
image: \${NETBIRD_DASHBOARD_IMAGE:-ghcr.io/netbirdio/dashboard-cloud:latest}
${COMBINED_SERVICE}:
image: \${NETBIRD_SERVER_IMAGE:-ghcr.io/netbirdio/netbird-server-cloud:latest}
environment:
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
EOF
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
cat <<EOF
depends_on:
postgres:
condition: service_healthy
volumes:
- ./${ENTERPRISE_CONFIG_FILE}:/etc/netbird/config.yaml.enterprise:ro
command: ["--config", "/etc/netbird/config.yaml.enterprise"]
postgres:
image: postgres:17
container_name: netbird-postgres
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
environment:
POSTGRES_USER: netbird
POSTGRES_PASSWORD: \${POSTGRES_PASSWORD}
POSTGRES_DB: netbird
volumes:
- netbird_postgres:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U netbird -d netbird"]
interval: 5s
timeout: 5s
retries: 20
EOF
fi
if [[ "$ENABLE_FLOW" == "yes" ]]; then
cat <<EOF
nats:
image: nats:2
container_name: netbird-nats
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
command: ["-m", "8222", "--jetstream", "--store_dir", "/data"]
volumes:
- netbird_nats_data:/data
flow-enricher:
image: ghcr.io/netbirdio/flow-enricher-cloud:latest
container_name: netbird-flow-enricher
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
depends_on:
postgres:
condition: service_healthy
nats:
condition: service_started
environment:
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
NB_DATADIR: /var/lib/netbird
NB_MANAGEMENT_STORE_ENGINE: postgres
NB_MANAGEMENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
NB_STORE_ENGINE_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
NB_TRAFFIC_EVENT_STORE_ENGINE: postgres
NB_TRAFFIC_EVENT_POSTGRES_DSN: "host=postgres user=netbird password=\${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
NB_MANAGEMENT_STORE_KEY: \${NETBIRD_ENCRYPTION_KEY}
NB_FLOW_ADAPTER_TYPE: nats
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
NB_FLOW_NATS_STREAM: traffic-events
NB_METRICS_PORT: 9091
NB_PERSISTENCE_RETENTION_PERIOD: 168h
flow-receiver:
image: ghcr.io/netbirdio/flow-receiver-cloud:latest
container_name: netbird-flow-receiver
restart: unless-stopped
networks: [${COMPOSE_NETWORK}]
depends_on:
nats:
condition: service_started
environment:
NB_LICENSE_KEY: \${NB_LICENSE_KEY}
NETBIRD_LICENSE_SERVER_BASE_URL: \${NETBIRD_LICENSE_SERVER_BASE_URL}
NB_FLOW_LISTEN_PORT: 80
NB_FLOW_ADAPTER_TYPE: nats
NB_FLOW_NATS_ENDPOINTS: nats://nats:4222
NB_FLOW_NATS_STREAM: traffic-events
NB_FLOW_AUTH_SECRET: \${NB_FLOW_AUTH_SECRET}
labels:
- traefik.enable=true
- traefik.http.routers.netbird-flow.rule=Host(\`${NETBIRD_HOSTNAME}\`) && PathPrefix(\`/flow.FlowService/\`)
- traefik.http.routers.netbird-flow.entrypoints=websecure
- traefik.http.routers.netbird-flow.tls=true
- traefik.http.routers.netbird-flow.tls.certresolver=letsencrypt
- traefik.http.routers.netbird-flow.service=netbird-flow-h2c
- traefik.http.routers.netbird-flow.priority=100
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.port=80
- traefik.http.services.netbird-flow-h2c.loadbalancer.server.scheme=h2c
EOF
fi
# Volume declarations for anything new the override introduced
local has_volumes="no"
if [[ "$MIGRATE_POSTGRES" == "yes" ]] || [[ "$ENABLE_FLOW" == "yes" ]]; then
has_volumes="yes"
fi
if [[ "$has_volumes" == "yes" ]]; then
cat <<EOF
volumes:
EOF
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo " netbird_postgres:"
fi
if [[ "$ENABLE_FLOW" == "yes" ]]; then
echo " netbird_nats_data:"
fi
fi
}
# Build config.yaml.enterprise by yq-editing the operator's existing
# config.yaml. We don't touch the original file.
render_enterprise_config() {
local pg_dsn="host=postgres user=netbird password=${POSTGRES_PASSWORD} dbname=netbird port=5432 sslmode=disable"
yq eval "
.server.store.engine = \"postgres\" |
.server.store.dsn = \"$pg_dsn\" |
.server.activityStore.engine = \"postgres\" |
.server.activityStore.dsn = \"$pg_dsn\" |
.server.authStore.engine = \"postgres\" |
.server.authStore.dsn = \"$pg_dsn\"
" "$CONFIG_YAML_HOST" > "$ENTERPRISE_CONFIG_FILE"
if [[ "$ENABLE_FLOW" == "yes" ]]; then
local flow_addr="${NETBIRD_DOMAIN}"
yq eval -i "
.server.trafficFlow.enabled = true |
.server.trafficFlow.address = \"$flow_addr\" |
.server.trafficFlow.interval = \"60s\"
" "$ENTERPRISE_CONFIG_FILE"
fi
}
# ---------------------------------------------------------------------------
# Execution steps
# ---------------------------------------------------------------------------
resolve_data_volume() {
local short="$1"
local actual
# Resolve project-prefixed volume name from Docker Compose config first.
actual=$($DOCKER_COMPOSE_COMMAND config 2>/dev/null | yq eval ".volumes.\"$short\".name" - 2>/dev/null)
if [[ -n "$actual" && "$actual" != "null" ]]; then
echo "$actual"
return
fi
# Relative bind mount: docker-compose resolves it against the compose
# file's directory, but `docker run -v` resolves it against the current
# working directory. Normalize to an absolute path so both interpretations
# agree (and the printed revert command works from any CWD).
if [[ "$short" == ./* || "$short" == ../* ]]; then
local compose_dir
compose_dir="$(cd "$(dirname "$COMPOSE_FILE")" && pwd)"
(
cd "$compose_dir"
cd "$(dirname "$short")"
printf '%s/%s\n' "$(pwd)" "$(basename "$short")"
)
return
fi
# Not a named volume (e.g. an absolute bind-mount path) — use it as-is.
echo "$short"
}
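Review note: the relative bind-mount branch of resolve_data_volume can be tried in isolation. The sketch below uses a hypothetical abs_from_compose_dir helper (not in the script) and a throwaway directory created with mktemp. It assumes the directory part of the relative path already exists.

```shell
#!/bin/bash
# Hypothetical helper: resolve a relative bind-mount path against the directory
# that holds the compose file, not against the caller's working directory.
abs_from_compose_dir() {
  local compose_file="$1"
  local rel="$2"
  (
    # The subshell keeps the cd calls from leaking into the caller.
    cd "$(dirname "$compose_file")" || exit 1
    cd "$(dirname "$rel")" || exit 1
    printf '%s/%s\n' "$(pwd)" "$(basename "$rel")"
  )
}

# Demo layout: <tmp>/stack/docker-compose.yml and <tmp>/stack/data
demo_root="$(mktemp -d)"
mkdir -p "$demo_root/stack/data"
touch "$demo_root/stack/docker-compose.yml"

# Run from an unrelated directory to show the result does not depend on CWD.
cd /
abs_from_compose_dir "$demo_root/stack/docker-compose.yml" "./data"
```

The printed path is absolute, so it can be passed to `docker run -v` from any working directory.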
backup_sqlite() {
BACKUP_DIR="$(pwd)/backups/sqlite-pre-enterprise-$(date +%Y%m%d-%H%M%S)"
mkdir -p "$BACKUP_DIR"
local data_volume_actual
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
echo "Backing up SQLite store from volume '$data_volume_actual' to $BACKUP_DIR ..."
docker run --rm \
-v "${data_volume_actual}:/var/lib/netbird:ro" \
-v "${BACKUP_DIR}:/backup" \
busybox \
sh -c 'cp -a /var/lib/netbird/. /backup/ 2>/dev/null || true'
local copied
copied=$(find "$BACKUP_DIR" -mindepth 1 | head -1)
if [[ -z "$copied" ]]; then
echo " ⚠ Backup directory is empty — the volume '$data_volume_actual' didn't contain data. Aborting." > /dev/stderr
exit 1
fi
echo " done"
}
run_migrate_store() {
echo "Running migrate-store (SQLite → Postgres) ..."
$DOCKER_COMPOSE_COMMAND run --rm "$COMBINED_SERVICE" migrate-store --config /etc/netbird/config.yaml.enterprise --verify
echo " done"
}
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
init_migration() {
DOCKER_COMPOSE_COMMAND=$(check_docker_compose)
check_yq
check_openssl
COMPOSE_FILE="${COMPOSE_FILE:-docker-compose.yml}"
if [[ ! -f "$COMPOSE_FILE" ]]; then
echo "$COMPOSE_FILE not found in $(pwd)." > /dev/stderr
exit 1
fi
if [[ -f "$OVERRIDE_FILE" ]] || [[ -f "$ENTERPRISE_CONFIG_FILE" ]]; then
echo "Migration artifacts already exist in $(pwd):"
[[ -f "$OVERRIDE_FILE" ]] && echo " $OVERRIDE_FILE"
[[ -f "$ENTERPRISE_CONFIG_FILE" ]] && echo " $ENTERPRISE_CONFIG_FILE"
echo ""
echo "Either you've already migrated, or a previous run was interrupted."
echo "To re-run cleanly: rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
exit 1
fi
COMBINED_SERVICE=$(detect_combined_service)
DASHBOARD_SERVICE=$(detect_dashboard_service)
CONFIG_YAML_HOST=$(detect_config_yaml_host_path)
DATA_VOLUME=$(detect_data_volume)
COMPOSE_NETWORK=$(detect_compose_network)
if [[ -z "$COMBINED_SERVICE" ]]; then
echo "Could not find a service running netbirdio/netbird-server* in $COMPOSE_FILE." > /dev/stderr
echo "This script targets the community combined-server deployment." > /dev/stderr
exit 1
fi
if [[ -z "$DASHBOARD_SERVICE" ]]; then
echo "Could not find a service running netbirdio/dashboard* in $COMPOSE_FILE." > /dev/stderr
exit 1
fi
if [[ -z "$CONFIG_YAML_HOST" ]]; then
echo "Could not find a config.yaml mount on $COMBINED_SERVICE (expected to bind-mount to /etc/netbird/config.yaml)." > /dev/stderr
exit 1
fi
if [[ ! -f "$CONFIG_YAML_HOST" ]]; then
echo "config.yaml host file not found at $CONFIG_YAML_HOST." > /dev/stderr
exit 1
fi
if [[ -z "$DATA_VOLUME" ]]; then
echo "Could not find a volume mounted at /var/lib/netbird on $COMBINED_SERVICE." > /dev/stderr
exit 1
fi
echo "Detected existing deployment:"
echo " Combined service: $COMBINED_SERVICE"
echo " Dashboard: $DASHBOARD_SERVICE"
echo " config.yaml: $CONFIG_YAML_HOST"
echo " Data volume: $DATA_VOLUME"
echo " Network: $COMPOSE_NETWORK"
echo ""
local proceed
proceed=$(read_yes_no "Proceed with migration?" "y")
if [[ "$proceed" != "yes" ]]; then
echo "Aborted."
exit 0
fi
# Step 1 — always (this is the point of the script)
MIGRATE_IMAGES="yes"
echo ""
echo "Step 1: Image swap (community → Enterprise). License key required."
NB_LICENSE_KEY=$(read_secret " License key")
GHCR_USERNAME="netbirdExtAccess1"
GHCR_TOKEN=$(read_secret " GHCR token (input hidden)")
# Step 2 — optional
echo ""
MIGRATE_POSTGRES=$(read_yes_no "Step 2: Migrate storage from SQLite to Postgres? (recommended)" "n")
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo ""
echo " ⚠ Data will be migrated from SQLite to Postgres. The SQLite store"
echo " will be backed up automatically. To fully revert later, restore"
echo " that backup and delete docker-compose.override.yml +"
echo " config.yaml.enterprise."
local confirm
confirm=$(read_yes_no " Continue?" "y")
if [[ "$confirm" != "yes" ]]; then
MIGRATE_POSTGRES="no"
echo " Skipping Postgres migration."
else
POSTGRES_PASSWORD=$(rand_password)
fi
fi
# Step 3 — optional, only if Postgres is on (flow requires Postgres)
echo ""
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
ENABLE_FLOW=$(read_yes_no "Step 3: Enable traffic flow? (requires Postgres)" "n")
if [[ "$ENABLE_FLOW" == "yes" ]]; then
# Auth secret MUST match server.authSecret from config.yaml
NB_FLOW_AUTH_SECRET=$(yq eval '.server.authSecret // ""' "$CONFIG_YAML_HOST")
if [[ -z "$NB_FLOW_AUTH_SECRET" ]] || [[ "$NB_FLOW_AUTH_SECRET" == "null" ]]; then
echo "Could not read server.authSecret from $CONFIG_YAML_HOST." > /dev/stderr
echo "Flow receiver auth must match the combined server's authSecret." > /dev/stderr
exit 1
fi
NETBIRD_DOMAIN=$(detect_exposed_address)
if [[ -z "$NETBIRD_DOMAIN" ]] || [[ "$NETBIRD_DOMAIN" == "null" ]]; then
NETBIRD_DOMAIN=$(read_required " Public NetBird URL (e.g. https://netbird.example.com)")
fi
# Strip protocol + port to leave just the hostname for the Traefik Host() rule.
NETBIRD_HOSTNAME=$(echo "$NETBIRD_DOMAIN" | sed -E 's,^https?://,,' | sed 's,:.*,,' | sed 's,/.*,,')
# We need the encryption key from the existing config.yaml for the enricher
NETBIRD_ENCRYPTION_KEY=$(yq eval '.server.store.encryptionKey // ""' "$CONFIG_YAML_HOST")
if [[ -z "$NETBIRD_ENCRYPTION_KEY" ]] || [[ "$NETBIRD_ENCRYPTION_KEY" == "null" ]]; then
echo "Could not read server.store.encryptionKey from $CONFIG_YAML_HOST." > /dev/stderr
exit 1
fi
fi
else
ENABLE_FLOW="no"
echo "Step 3 (traffic flow) skipped — requires Postgres."
fi
}
apply_changes() {
echo ""
echo "Writing $OVERRIDE_FILE ..."
install -m 644 /dev/null "$OVERRIDE_FILE"
render_override > "$OVERRIDE_FILE"
if [[ -z "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
sed -i.bak '/NETBIRD_LICENSE_SERVER_BASE_URL/d' "$OVERRIDE_FILE" && rm -f "$OVERRIDE_FILE.bak"
fi
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo "Writing $ENTERPRISE_CONFIG_FILE ..."
install -m 600 /dev/null "$ENTERPRISE_CONFIG_FILE"
render_enterprise_config
fi
# Persist secrets that the override file references via env interpolation.
# We write them to a .env file in the current directory; docker compose
# picks it up automatically.
echo "Writing .env additions (mode 600) ..."
local ENV_FILE=".env"
touch "$ENV_FILE"
chmod 600 "$ENV_FILE"
{
echo ""
echo "# Added by migrate-to-enterprise.sh on $(date -u +%Y-%m-%dT%H:%M:%SZ)"
echo "NB_LICENSE_KEY=${NB_LICENSE_KEY}"
if [[ -n "${NETBIRD_LICENSE_SERVER_BASE_URL:-}" ]]; then
echo "NETBIRD_LICENSE_SERVER_BASE_URL=${NETBIRD_LICENSE_SERVER_BASE_URL}"
fi
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo "POSTGRES_PASSWORD=${POSTGRES_PASSWORD}"
fi
if [[ "$ENABLE_FLOW" == "yes" ]]; then
echo "NB_FLOW_AUTH_SECRET=${NB_FLOW_AUTH_SECRET}"
echo "NETBIRD_ENCRYPTION_KEY=${NETBIRD_ENCRYPTION_KEY}"
fi
} >> "$ENV_FILE"
echo ""
echo "Logging in to ghcr.io ..."
printf '%s' "$GHCR_TOKEN" | docker login ghcr.io -u "$GHCR_USERNAME" --password-stdin
unset GHCR_TOKEN
echo ""
echo "Pulling enterprise images ..."
$DOCKER_COMPOSE_COMMAND pull
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
echo ""
echo "Stopping existing services (volumes preserved) ..."
$DOCKER_COMPOSE_COMMAND down
backup_sqlite
echo ""
echo "Starting Postgres ..."
$DOCKER_COMPOSE_COMMAND up -d postgres
# Wait for healthy
local counter=0
echo -n "Waiting for Postgres to become ready"
while ! $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird &> /dev/null; do
echo -n " ."
sleep 2
counter=$((counter + 1))
if [[ $counter -ge 60 ]]; then
echo ""
echo "Postgres did not become ready in 120s. Recent logs:"
$DOCKER_COMPOSE_COMMAND logs --tail=20 postgres
exit 1
fi
done
echo " done"
run_migrate_store
fi
echo ""
echo "Bringing up all services ..."
$DOCKER_COMPOSE_COMMAND up -d
echo ""
echo "Migration complete."
}
print_summary() {
echo ""
echo "──────────────────────────────────────────────────────────────────────"
echo " Summary"
echo "──────────────────────────────────────────────────────────────────────"
echo " Images: swapped to enterprise"
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " Storage: Postgres (data migrated from SQLite)"
[[ "$MIGRATE_POSTGRES" != "yes" ]] && echo " Storage: SQLite (unchanged)"
[[ "$ENABLE_FLOW" == "yes" ]] && echo " Traffic flow: enabled"
[[ "$ENABLE_FLOW" != "yes" ]] && echo " Traffic flow: disabled"
echo ""
echo " Generated files (next to your docker-compose.yml):"
echo " $OVERRIDE_FILE"
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " $ENTERPRISE_CONFIG_FILE"
echo " .env (license key + secrets, mode 600)"
[[ "$MIGRATE_POSTGRES" == "yes" ]] && echo " backups/sqlite-pre-enterprise-*/ (SQLite backup)"
echo ""
echo " Tail logs:"
echo " $DOCKER_COMPOSE_COMMAND logs -f $COMBINED_SERVICE"
echo ""
echo "──────────────────────────────────────────────────────────────────────"
echo " To revert"
echo "──────────────────────────────────────────────────────────────────────"
echo " $DOCKER_COMPOSE_COMMAND down"
if [[ "$MIGRATE_POSTGRES" == "yes" ]]; then
# Resolve project-prefixed volume names now (before override is removed).
local pg_volume data_volume_actual
pg_volume=$(resolve_data_volume "netbird_postgres")
data_volume_actual=$(resolve_data_volume "$DATA_VOLUME")
echo " # Remove the Postgres volume FIRST, before deleting the override file:"
echo " docker volume rm $pg_volume"
echo " # Restore SQLite from the backup created during this run:"
echo " docker run --rm -v ${data_volume_actual}:/var/lib/netbird -v ${BACKUP_DIR}:/backup busybox sh -c 'cp -a /backup/. /var/lib/netbird/'"
fi
echo " rm -f $OVERRIDE_FILE $ENTERPRISE_CONFIG_FILE"
echo " # Remove migrate-to-enterprise.sh additions from .env (search for the timestamp marker)"
echo " $DOCKER_COMPOSE_COMMAND up -d"
echo "──────────────────────────────────────────────────────────────────────"
}
# ---------------------------------------------------------------------------
# Run
# ---------------------------------------------------------------------------
init_migration
apply_changes
print_summary
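
Note on the hostname derivation in Step 3 of init_migration: the public URL is reduced to a bare hostname for the Traefik Host() rule by three chained sed substitutions (scheme, then port, then path). A minimal sketch of that pipeline as a standalone function follows. The name `url_to_hostname` is hypothetical and does not appear in the script. Like the original, it cuts at the first colon, so bracketed IPv6 literals and URLs with userinfo are not handled.

```shell
# url_to_hostname URL
# Hypothetical helper (not part of migrate-to-enterprise.sh) that mirrors the
# script's sed pipeline: strip the scheme, then the port, then the path.
url_to_hostname() {
  printf '%s\n' "$1" \
    | sed -E 's,^https?://,,' \
    | sed 's,:.*,,' \
    | sed 's,/.*,,'
}

# Sample inputs:
#   url_to_hostname 'https://netbird.example.com:443/api'  -> netbird.example.com
#   url_to_hostname 'http://netbird.example.com/path'      -> netbird.example.com
#   url_to_hostname 'netbird.example.com'                  -> netbird.example.com
```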

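Note on the Postgres readiness wait in apply_changes: it is a bounded polling loop that gives up after 60 attempts spaced 2 seconds apart. The same pattern can be sketched as a reusable helper. The name `wait_until` and its argument order are assumptions made for illustration, not something the script defines.

```shell
# wait_until MAX_ATTEMPTS DELAY_SECONDS COMMAND [ARGS...]
# Hypothetical helper (not part of migrate-to-enterprise.sh).
# Re-runs COMMAND until it succeeds, sleeping DELAY_SECONDS after each failure.
# Returns 0 as soon as COMMAND succeeds, 1 after MAX_ATTEMPTS failed attempts.
wait_until() {
  local max_attempts="$1" delay="$2"
  shift 2
  local attempt=0
  until "$@" > /dev/null 2>&1; do
    attempt=$((attempt + 1))
    if [ "$attempt" -ge "$max_attempts" ]; then
      return 1
    fi
    sleep "$delay"
  done
  return 0
}
```

With the script's own variables, the wait could then be written roughly as `wait_until 60 2 $DOCKER_COMPOSE_COMMAND exec -T postgres pg_isready -U netbird -d netbird`, leaving the log dump and `exit 1` to the caller on a non-zero return.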
management/cmd/admin.go (new file, 89 lines)

@@ -0,0 +1,89 @@
package cmd
import (
"context"
"fmt"
"path/filepath"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
admincmd "github.com/netbirdio/netbird/management/cmd/admin"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var adminDatadir string
// newAdminCommands creates the admin command tree with management-specific resource openers.
func newAdminCommands() *cobra.Command {
cmd := admincmd.NewCommands(withAdminResources)
cmd.PersistentFlags().StringVar(&adminDatadir, "datadir", "", "Override the data directory from config (used for store.db and the default idp.db)")
cmd.AddCommand(tokencmd.NewCommands(withAdminTokenStore))
return cmd
}
// withAdminResources initializes logging, loads config, opens the management store
// and embedded IdP storage, and calls fn.
func withAdminResources(cmd *cobra.Command, fn func(ctx context.Context, resources admincmd.Resources) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, config *nbconfig.Config) error {
idpStorage, err := admincmd.OpenEmbeddedIDPStorage(config.EmbeddedIdP)
if err != nil {
return err
}
defer func() {
if err := idpStorage.Close(); err != nil {
log.Debugf("close embedded IdP storage: %v", err)
}
}()
return fn(ctx, admincmd.Resources{Store: managementStore, IDPStorage: idpStorage})
})
}
// withAdminTokenStore opens only the management store for admin token commands.
func withAdminTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
return withAdminStore(cmd, func(ctx context.Context, managementStore store.Store, _ *nbconfig.Config) error {
return fn(ctx, managementStore)
})
}
func withAdminStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store, config *nbconfig.Config) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if adminDatadir != "" {
oldDatadir := datadir
datadir = adminDatadir
if config.EmbeddedIdP != nil && config.EmbeddedIdP.Storage.Type == "sqlite3" {
defaultIDPFile := filepath.Join(oldDatadir, "idp.db")
if config.EmbeddedIdP.Storage.Config.File == "" || config.EmbeddedIdP.Storage.Config.File == defaultIDPFile {
config.EmbeddedIdP.Storage.Config.File = filepath.Join(datadir, "idp.db")
}
}
}
managementStore, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := managementStore.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, managementStore, config)
}


@@ -0,0 +1,441 @@
// Package admincmd provides reusable cobra commands for self-hosted administrator helpers.
// Both the management and combined binaries use these commands, each providing
// their own opener to handle config loading and storage initialization.
package admincmd
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"os"
"strings"
"github.com/dexidp/dex/storage"
"github.com/spf13/cobra"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
const (
localConnectorID = "local"
dashboardClientID = "netbird-dashboard"
cliClientID = "netbird-cli"
defaultTOTPAuthenticatorID = "default-totp"
)
// Resources contains the storages required by the admin commands.
type Resources struct {
Store store.Store
IDPStorage storage.Storage
}
// Opener initializes command resources from the command context and calls fn.
type Opener func(cmd *cobra.Command, fn func(ctx context.Context, resources Resources) error) error
type userSelector struct {
email string
userID string
}
func (s userSelector) normalized() userSelector {
return userSelector{
email: strings.TrimSpace(s.email),
userID: strings.TrimSpace(s.userID),
}
}
func (s userSelector) validate() error {
s = s.normalized()
if (s.email == "") == (s.userID == "") {
return fmt.Errorf("provide exactly one of --email or --user-id")
}
return nil
}
// NewCommands creates the admin command tree with the given resource opener.
func NewCommands(opener Opener) *cobra.Command {
adminCmd := &cobra.Command{
Use: "admin",
Short: "Self-hosted administrator helpers",
Long: "Administrative helpers for self-hosted deployments using the embedded identity provider.",
}
userCmd := &cobra.Command{
Use: "user",
Short: "Manage local embedded IdP users",
}
var passwordSelector userSelector
var password string
var passwordFile string
passwordCmd := &cobra.Command{
Use: "change-password (--email email | --user-id id) (--password password | --password-file path)",
Aliases: []string{"set-password"},
Short: "Change a local user's password",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
newPassword, err := resolvePasswordInput(cmd, password, passwordFile)
if err != nil {
return err
}
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runChangePassword(ctx, resources.IDPStorage, cmd.OutOrStdout(), passwordSelector, newPassword)
})
},
}
addUserSelectorFlags(passwordCmd, &passwordSelector)
passwordCmd.Flags().StringVar(&password, "password", "", "New password for the user")
passwordCmd.Flags().StringVar(&passwordFile, "password-file", "", "Read new password from file ('-' for stdin)")
var resetSelector userSelector
resetMFACmd := &cobra.Command{
Use: "reset-mfa (--email email | --user-id id)",
Short: "Reset a local user's MFA enrollment",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runResetMFA(ctx, resources.IDPStorage, cmd.OutOrStdout(), resetSelector)
})
},
}
addUserSelectorFlags(resetMFACmd, &resetSelector)
userCmd.AddCommand(passwordCmd, resetMFACmd)
mfaCmd := &cobra.Command{
Use: "mfa",
Short: "Manage local MFA for embedded IdP users",
}
enableCmd := &cobra.Command{
Use: "enable",
Short: "Enable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), true)
})
},
}
disableCmd := &cobra.Command{
Use: "disable",
Short: "Disable MFA for local embedded IdP users",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runSetMFAEnabled(ctx, resources, cmd.OutOrStdout(), false)
})
},
}
statusCmd := &cobra.Command{
Use: "status",
Short: "Show local MFA status",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, _ []string) error {
return opener(cmd, func(ctx context.Context, resources Resources) error {
return runMFAStatus(ctx, resources, cmd.OutOrStdout())
})
},
}
mfaCmd.AddCommand(enableCmd, disableCmd, statusCmd)
adminCmd.AddCommand(userCmd, mfaCmd)
return adminCmd
}
// OpenEmbeddedIDPStorage opens the Dex storage configured for the embedded IdP.
func OpenEmbeddedIDPStorage(cfg *idp.EmbeddedIdPConfig) (storage.Storage, error) {
if cfg == nil || !cfg.Enabled {
return nil, fmt.Errorf("admin commands require the embedded IdP to be enabled")
}
yamlConfig, err := cfg.ToYAMLConfig()
if err != nil {
return nil, fmt.Errorf("build embedded IdP config: %w", err)
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
st, err := yamlConfig.Storage.OpenStorage(logger)
if err != nil {
return nil, fmt.Errorf("open embedded IdP storage: %w", err)
}
return st, nil
}
func addUserSelectorFlags(cmd *cobra.Command, selector *userSelector) {
cmd.Flags().StringVar(&selector.email, "email", "", "User email")
cmd.Flags().StringVar(&selector.userID, "user-id", "", "User ID")
}
func resolvePasswordInput(cmd *cobra.Command, password, passwordFile string) (string, error) {
if password != "" && passwordFile != "" {
return "", fmt.Errorf("provide only one of --password or --password-file")
}
if passwordFile == "" {
return password, nil
}
var data []byte
var err error
if passwordFile == "-" {
data, err = io.ReadAll(cmd.InOrStdin())
} else {
data, err = os.ReadFile(passwordFile)
}
if err != nil {
return "", fmt.Errorf("read password: %w", err)
}
return strings.TrimRight(string(data), "\r\n"), nil
}
func runChangePassword(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector, password string) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
if password == "" {
return fmt.Errorf("password is required")
}
if err := server.ValidatePassword(password); err != nil {
return fmt.Errorf("invalid password: %w", err)
}
user, err := findLocalUser(ctx, idpStorage, selector)
if err != nil {
return err
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("hash password: %w", err)
}
if err := idpStorage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
old.Hash = hash
return old, nil
}); err != nil {
return fmt.Errorf("update password for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Password updated for %s.\n", user.Email)
return nil
}
func runResetMFA(ctx context.Context, idpStorage storage.Storage, w io.Writer, selector userSelector) error {
if idpStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
selector = selector.normalized()
if err := selector.validate(); err != nil {
return err
}
user, err := findLocalUser(ctx, idpStorage, selector)
if err != nil {
return err
}
reset := false
err = idpStorage.UpdateUserIdentity(ctx, user.UserID, localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
reset = reset || len(old.MFASecrets) > 0 || len(old.WebAuthnCredentials) > 0
old.MFASecrets = map[string]*storage.MFASecret{}
old.WebAuthnCredentials = map[string][]storage.WebAuthnCredential{}
return old, nil
})
if errors.Is(err, storage.ErrNotFound) {
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
return nil
}
if err != nil {
return fmt.Errorf("reset MFA for %s: %w", user.Email, err)
}
if err := deleteLocalAuthSession(ctx, idpStorage, user.UserID); err != nil {
return err
}
if reset {
_, _ = fmt.Fprintf(w, "MFA reset for %s. The user will re-enroll at next login.\n", user.Email)
} else {
_, _ = fmt.Fprintf(w, "No MFA enrollment found for %s.\n", user.Email)
}
return nil
}
func runSetMFAEnabled(ctx context.Context, resources Resources, w io.Writer, enabled bool) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accounts := resources.Store.GetAllAccounts(ctx)
if len(accounts) != 1 {
return fmt.Errorf("expected exactly one account, got %d; local MFA is supported only in single-account embedded IdP deployments", len(accounts))
}
settings := &types.Settings{}
if accounts[0].Settings != nil {
settings = accounts[0].Settings.Copy()
}
settings.LocalMfaEnabled = enabled
if err := resources.Store.SaveAccountSettings(ctx, accounts[0].Id, settings); err != nil {
return fmt.Errorf("save local MFA account setting: %w", err)
}
if err := setIDPClientsMFA(ctx, resources.IDPStorage, enabled); err != nil {
return err
}
state := "disabled"
if enabled {
state = "enabled"
}
_, _ = fmt.Fprintf(w, "Local MFA %s.\n", state)
return nil
}
func runMFAStatus(ctx context.Context, resources Resources, w io.Writer) error {
if resources.Store == nil {
return fmt.Errorf("management store is required")
}
if resources.IDPStorage == nil {
return fmt.Errorf("embedded IdP storage is required")
}
accounts := resources.Store.GetAllAccounts(ctx)
accountStatus := "unknown"
if len(accounts) == 1 && accounts[0].Settings != nil {
accountStatus = "disabled"
if accounts[0].Settings.LocalMfaEnabled {
accountStatus = "enabled"
}
}
clientStatus, err := idpClientsMFAStatus(ctx, resources.IDPStorage)
if err != nil {
return err
}
_, _ = fmt.Fprintf(w, "Account setting: %s\n", accountStatus)
_, _ = fmt.Fprintf(w, "Embedded IdP clients: %s\n", clientStatus)
return nil
}
func findLocalUser(ctx context.Context, idpStorage storage.Storage, selector userSelector) (storage.Password, error) {
selector = selector.normalized()
if err := selector.validate(); err != nil {
return storage.Password{}, err
}
if selector.email != "" {
user, err := idpStorage.GetPassword(ctx, selector.email)
if errors.Is(err, storage.ErrNotFound) {
return storage.Password{}, fmt.Errorf("local user with email %q not found", selector.email)
}
if err != nil {
return storage.Password{}, fmt.Errorf("get local user by email %q: %w", selector.email, err)
}
return user, nil
}
rawUserID := selector.userID
if decodedUserID, _, err := nbdex.DecodeDexUserID(selector.userID); err == nil && decodedUserID != "" {
rawUserID = decodedUserID
}
users, err := idpStorage.ListPasswords(ctx)
if err != nil {
return storage.Password{}, fmt.Errorf("list local users: %w", err)
}
for _, user := range users {
if user.UserID == rawUserID || user.UserID == selector.userID {
return user, nil
}
}
return storage.Password{}, fmt.Errorf("local user with ID %q not found", selector.userID)
}
func deleteLocalAuthSession(ctx context.Context, idpStorage storage.Storage, userID string) error {
err := idpStorage.DeleteAuthSession(ctx, userID, localConnectorID)
if err == nil || errors.Is(err, storage.ErrNotFound) {
return nil
}
return fmt.Errorf("delete local auth session for user %s: %w", userID, err)
}
func setIDPClientsMFA(ctx context.Context, idpStorage storage.Storage, enabled bool) error {
var mfaChain []string
if enabled {
mfaChain = []string{defaultTOTPAuthenticatorID}
}
for _, clientID := range []string{cliClientID, dashboardClientID} {
if err := idpStorage.UpdateClient(ctx, clientID, func(old storage.Client) (storage.Client, error) {
old.MFAChain = mfaChain
return old, nil
}); err != nil {
if errors.Is(err, storage.ErrNotFound) {
return fmt.Errorf("embedded IdP client %q not found; start the management server once before toggling MFA", clientID)
}
return fmt.Errorf("update MFA chain on embedded IdP client %q: %w", clientID, err)
}
}
return nil
}
func idpClientsMFAStatus(ctx context.Context, idpStorage storage.Storage) (string, error) {
clientIDs := []string{cliClientID, dashboardClientID}
enabledCount := 0
for _, clientID := range clientIDs {
client, err := idpStorage.GetClient(ctx, clientID)
if errors.Is(err, storage.ErrNotFound) {
return "unknown", fmt.Errorf("embedded IdP client %q not found", clientID)
}
if err != nil {
return "unknown", fmt.Errorf("get embedded IdP client %q: %w", clientID, err)
}
if hasAuthenticator(client.MFAChain, defaultTOTPAuthenticatorID) {
enabledCount++
}
}
switch enabledCount {
case 0:
return "disabled", nil
case len(clientIDs):
return "enabled", nil
default:
return "partially enabled", nil
}
}
func hasAuthenticator(chain []string, authenticatorID string) bool {
for _, id := range chain {
if id == authenticatorID {
return true
}
}
return false
}
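
Usage note for the change-password command defined above: because `--password-file -` reads the new password from stdin and trailing newlines are trimmed, the password can be piped in instead of being passed with `--password`. A minimal sketch of a wrapper follows. The binary name `netbird-mgmt` and the wrapper name `change_local_password` are assumptions; substitute whichever management or combined binary your deployment runs.

```shell
# Assumption: the binary name. Point NB_ADMIN_BIN at the management or
# combined binary you actually run.
NB_ADMIN_BIN="${NB_ADMIN_BIN:-netbird-mgmt}"

# change_local_password EMAIL PASSWORD
# Hypothetical wrapper: sends the password on stdin, which the command reads
# via --password-file -, so it is not placed in the binary's argument list.
change_local_password() {
  local email="$1" new_password="$2"
  printf '%s\n' "$new_password" |
    "$NB_ADMIN_BIN" admin user change-password --email "$email" --password-file -
}
```

The `reset-mfa` subcommand takes the same `--email` or `--user-id` selector and needs no password input.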


@@ -0,0 +1,160 @@
package admincmd
import (
"bytes"
"context"
"io"
"log/slog"
"strings"
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
nbdex "github.com/netbirdio/netbird/idp/dex"
)
func newTestIDPStorage(t *testing.T) storage.Storage {
t.Helper()
st := memory.New(slog.New(slog.NewTextHandler(io.Discard, nil)))
hash, err := bcrypt.GenerateFromPassword([]byte("OldPass1!"), bcrypt.DefaultCost)
require.NoError(t, err)
require.NoError(t, st.CreatePassword(context.Background(), storage.Password{
Email: "user@example.com",
Username: "User",
UserID: "user-1",
Hash: hash,
}))
require.NoError(t, st.CreateUserIdentity(context.Background(), storage.UserIdentity{
UserID: "user-1",
ConnectorID: localConnectorID,
MFASecrets: map[string]*storage.MFASecret{
defaultTOTPAuthenticatorID: {
AuthenticatorID: defaultTOTPAuthenticatorID,
Type: "TOTP",
Secret: "otpauth://totp/NetBird:user@example.com?secret=ABC",
Confirmed: true,
CreatedAt: time.Now(),
},
},
WebAuthnCredentials: map[string][]storage.WebAuthnCredential{
"webauthn": {{CredentialID: []byte("credential")}},
},
}))
require.NoError(t, st.CreateAuthSession(context.Background(), storage.AuthSession{
UserID: "user-1",
ConnectorID: localConnectorID,
Nonce: "nonce",
}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: cliClientID, Name: "CLI"}))
require.NoError(t, st.CreateClient(context.Background(), storage.Client{ID: dashboardClientID, Name: "Dashboard"}))
return st
}
func TestRunChangePassword(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
err := runChangePassword(ctx, st, &out, userSelector{email: "user@example.com"}, "NewPass1!")
require.NoError(t, err)
require.Contains(t, out.String(), "Password updated")
user, err := st.GetPassword(ctx, "user@example.com")
require.NoError(t, err)
require.NoError(t, bcrypt.CompareHashAndPassword(user.Hash, []byte("NewPass1!")))
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunChangePasswordValidatesPassword(t *testing.T) {
st := newTestIDPStorage(t)
err := runChangePassword(context.Background(), st, io.Discard, userSelector{email: "user@example.com"}, "short")
require.Error(t, err)
require.Contains(t, err.Error(), "invalid password")
}
func TestRunResetMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
var out bytes.Buffer
encodedUserID := nbdex.EncodeDexUserID("user-1", localConnectorID)
err := runResetMFA(ctx, st, &out, userSelector{userID: encodedUserID})
require.NoError(t, err)
require.Contains(t, out.String(), "MFA reset")
identity, err := st.GetUserIdentity(ctx, "user-1", localConnectorID)
require.NoError(t, err)
require.Empty(t, identity.MFASecrets)
require.Empty(t, identity.WebAuthnCredentials)
_, err = st.GetAuthSession(ctx, "user-1", localConnectorID)
require.ErrorIs(t, err, storage.ErrNotFound)
}
func TestRunResetMFAWithoutEnrollment(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, st.UpdateUserIdentity(ctx, "user-1", localConnectorID, func(old storage.UserIdentity) (storage.UserIdentity, error) {
old.MFASecrets = nil
old.WebAuthnCredentials = nil
return old, nil
}))
var out bytes.Buffer
err := runResetMFA(ctx, st, &out, userSelector{email: "user@example.com"})
require.NoError(t, err)
require.Contains(t, out.String(), "No MFA enrollment found")
}
func TestSetIDPClientsMFA(t *testing.T) {
ctx := context.Background()
st := newTestIDPStorage(t)
require.NoError(t, setIDPClientsMFA(ctx, st, true))
status, err := idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "enabled", status)
require.NoError(t, setIDPClientsMFA(ctx, st, false))
status, err = idpClientsMFAStatus(ctx, st)
require.NoError(t, err)
require.Equal(t, "disabled", status)
}
func TestUserSelectorValidate(t *testing.T) {
require.NoError(t, userSelector{email: " user@example.com "}.validate())
require.NoError(t, userSelector{userID: "user-1"}.validate())
require.Error(t, userSelector{}.validate())
require.Error(t, userSelector{email: "user@example.com", userID: "user-1"}.validate())
}
func TestFindLocalUserNotFound(t *testing.T) {
st := newTestIDPStorage(t)
_, err := findLocalUser(context.Background(), st, userSelector{email: "missing@example.com"})
require.Error(t, err)
require.True(t, strings.Contains(err.Error(), "not found"))
}
func TestResolvePasswordInputFromStdin(t *testing.T) {
cmd := &cobra.Command{}
cmd.SetIn(strings.NewReader("NewPass1!\n"))
password, err := resolvePasswordInput(cmd, "", "-")
require.NoError(t, err)
require.Equal(t, "NewPass1!", password)
}
func TestResolvePasswordInputRejectsMultipleSources(t *testing.T) {
_, err := resolvePasswordInput(&cobra.Command{}, "NewPass1!", "-")
require.Error(t, err)
}


@@ -83,7 +83,7 @@ func init() {
 rootCmd.AddCommand(migrationCmd)
-tc := newTokenCommands()
-tc.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
-rootCmd.AddCommand(tc)
+ac := newAdminCommands()
+ac.PersistentFlags().StringVar(&nbconfig.MgmtConfigPath, "config", defaultMgmtConfig, "Netbird config file location")
+rootCmd.AddCommand(ac)
 }


@@ -1,55 +0,0 @@
package cmd
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/netbirdio/netbird/formatter/hook"
tokencmd "github.com/netbirdio/netbird/management/cmd/token"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/util"
)
var tokenDatadir string
// newTokenCommands creates the token command tree with management-specific store opener.
func newTokenCommands() *cobra.Command {
cmd := tokencmd.NewCommands(withTokenStore)
cmd.PersistentFlags().StringVar(&tokenDatadir, "datadir", "", "Override the data directory from config (where store.db is located)")
return cmd
}
// withTokenStore initializes logging, loads config, opens the store, and calls fn.
func withTokenStore(cmd *cobra.Command, fn func(ctx context.Context, s store.Store) error) error {
if err := util.InitLog("error", "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := context.WithValue(cmd.Context(), hook.ExecutionContextKey, hook.SystemSource) //nolint:staticcheck
config, err := LoadMgmtConfig(ctx, nbconfig.MgmtConfigPath)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
datadir := config.Datadir
if tokenDatadir != "" {
datadir = tokenDatadir
}
s, err := store.NewStore(ctx, config.StoreConfig.Engine, datadir, nil, true)
if err != nil {
return fmt.Errorf("create store: %w", err)
}
defer func() {
if err := s.Close(ctx); err != nil {
log.Debugf("close store: %v", err)
}
}()
return fn(ctx, s)
}


@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
 t.Helper()
 tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
 pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
-srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
+srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
 return srv
 }
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
 tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
 pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
-proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
+proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
 proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
 require.NoError(t, err)
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
 tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
 pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
-proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
+proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
 proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
 require.NoError(t, err)


@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer { func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
return Create(s, func() *nbgrpc.ProxyServiceServer { return Create(s, func() *nbgrpc.ProxyServiceServer {
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store()) proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
s.AfterInit(func(s *BaseServer) { s.AfterInit(func(s *BaseServer) {
proxyService.SetServiceManager(s.ServiceManager()) proxyService.SetServiceManager(s.ServiceManager())
proxyService.SetProxyController(s.ServiceProxyController()) proxyService.SetProxyController(s.ServiceProxyController())


@@ -33,6 +33,8 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/management/server/users"
proxyauth "github.com/netbirdio/netbird/proxy/auth" proxyauth "github.com/netbirdio/netbird/proxy/auth"
@@ -82,6 +84,9 @@ type ProxyServiceServer struct {
// Manager for users // Manager for users
usersManager users.Manager usersManager users.Manager
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
idpManager idp.Manager
// Store for one-time authentication tokens // Store for one-time authentication tokens
tokenStore *OneTimeTokenStore tokenStore *OneTimeTokenStore
@@ -157,7 +162,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
} }
// NewProxyServiceServer creates a new proxy service server. // NewProxyServiceServer creates a new proxy service server.
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer { func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{ s := &ProxyServiceServer{
accessLogManager: accessLogMgr, accessLogManager: accessLogMgr,
@@ -166,6 +171,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
pkceVerifierStore: pkceStore, pkceVerifierStore: pkceStore,
peersManager: peersManager, peersManager: peersManager,
usersManager: usersManager, usersManager: usersManager,
idpManager: idpManager,
proxyManager: proxyMgr, proxyManager: proxyMgr,
tokenChecker: tokenChecker, tokenChecker: tokenChecker,
snapshotBatchSize: snapshotBatchSizeFromEnv(), snapshotBatchSize: snapshotBatchSizeFromEnv(),
@@ -1702,22 +1708,7 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
 }
 groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
-// Resolve the principal: when the peer is linked to a user, the human
-// is the principal so multiple peers owned by the same user share a
-// single identity. Unlinked peers (machine agents) are their own
-// principal keyed on peer.ID. displayIdentity is what upstream gateways
-// tag spend with — user.Email when linked, peer.Name when not.
-principalID := peer.ID
-displayIdentity := peer.Name
-if peer.UserID != "" {
-if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
-principalID = user.Id
-if user.Email != "" {
-displayIdentity = user.Email
-}
-}
-}
+principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
 if err := checkPeerGroupAccess(service, groupIDs); err != nil {
 log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
@@ -1754,6 +1745,45 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
}, nil }, nil
} }
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
// user or peer ID, and peer name or user email.
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
// Resolve the principal: when the peer is linked to a user, the human is the
// principal so multiple peers owned by the same user share a single
// identity. Unlinked peers (machine agents) are their own principal keyed on
// peer.ID. displayIdentity is what upstream gateways tag spend with —
// user.Email when linked, peer.Name when not.
// If the peer isn't associated with a user, return the peer info directly.
if peer.UserID == "" {
return peer.ID, peer.Name
}
// Otherwise, if the peer is linked to a user, the user is the principal and
// if an IdP is available, we gather details on the user from it.
principalID := peer.UserID
displayIdentity := peer.Name
// Stored column first (cheap, but often empty for OIDC-provisioned users).
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
principalID = user.Id
if user.Email != "" {
displayIdentity = user.Email
}
}
// IdP enrichment wins when available — the stored email column is a
// best-effort cache and is frequently empty for OIDC users. Enrichment
// failures must never fail the RPC; we simply keep the stored/peer identity.
if s.idpManager != nil {
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
displayIdentity = ud.Email
} else if uerr != nil {
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
}
}
return principalID, displayIdentity
}
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required // checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
// groups. Private services authorise against AccessGroups (empty list fails // groups. Private services authorise against AccessGroups (empty list fails
// closed — Validate() rejects that at save time but the RPC is the security // closed — Validate() rejects that at save time but the RPC is the security
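
The fallback order used by getTunnelPeerInfo above (IdP email, then stored email, then peer name) can be sketched in isolation. This is a minimal illustration, not the code from the commit: peerInfo, emailLookup, resolveIdentity and errIdPDown are hypothetical stand-ins for the real peer, users manager and IdP manager types, and the sketch always uses the peer's UserID as the principal rather than the stored user's Id.

```go
package main

import (
	"errors"
	"fmt"
)

// peerInfo is a hypothetical, trimmed-down stand-in for the real peer type.
type peerInfo struct {
	ID, Name, UserID string
}

// emailLookup is a hypothetical stand-in for a users-manager or IdP lookup:
// it returns the email known for a user ID, or an error.
type emailLookup func(userID string) (string, error)

// errIdPDown is a hypothetical error used to model an unreachable IdP.
var errIdPDown = errors.New("idp unreachable")

// resolveIdentity returns (principalID, displayIdentity).
// Unlinked peers are their own principal. For linked peers the display
// identity starts as the peer name, is replaced by a non-empty stored email,
// and is replaced again by a non-empty IdP email. Lookup errors are ignored
// so that enrichment can never fail the caller. A nil lookup is skipped.
func resolveIdentity(p peerInfo, stored, idp emailLookup) (string, string) {
	if p.UserID == "" {
		return p.ID, p.Name
	}
	display := p.Name
	for _, lookup := range []emailLookup{stored, idp} {
		if lookup == nil {
			continue
		}
		if email, err := lookup(p.UserID); err == nil && email != "" {
			display = email
		}
	}
	return p.UserID, display
}

func main() {
	stored := func(string) (string, error) { return "stored@example.com", nil }
	idpDown := func(string) (string, error) { return "", errIdPDown }

	linked := peerInfo{ID: "peer1", Name: "laptop", UserID: "user1"}
	id, display := resolveIdentity(linked, stored, idpDown)
	fmt.Println(id, display)
}
```

Applying the lookups in order of increasing authority keeps the "later source wins" rule in one loop, which mirrors the cases exercised by the table-driven test in this change.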


@@ -3,14 +3,19 @@ package grpc
import ( import (
"context" "context"
"errors" "errors"
"net"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
) )
type mockReverseProxyManager struct { type mockReverseProxyManager struct {
@@ -137,6 +142,52 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
return user, nil, nil return user, nil, nil
} }
// mockTunnelPeersManager implements only the two peers.Manager methods that
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
// panics if any unexpected method is invoked).
type mockTunnelPeersManager struct {
peers.Manager
peer *peer.Peer
peerErr error
groups []*types.Group
groupsErr error
}
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
return m.peer, m.peerErr
}
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
return m.peer, m.groups, m.groupsErr
}
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
// an IdP that knows nothing about the user.
type mockTunnelIdpManager struct {
idp.Manager
email string
hasData bool
err error
gotCalls int
gotMeta []idp.AppMetadata
}
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
m.gotCalls++
m.gotMeta = append(m.gotMeta, meta)
if m.err != nil {
return nil, m.err
}
if !m.hasData {
// This might not be a thing any of the actual IDP implementations do,
// i.e. return a nil value with no error, but it seems valuable to test
// that behavior here.
return nil, nil //nolint:nilnil
}
return &idp.UserData{ID: userID, Email: m.email}, nil
}
func TestValidateUserGroupAccess(t *testing.T) { func TestValidateUserGroupAccess(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -354,6 +405,163 @@ func TestValidateUserGroupAccess(t *testing.T) {
} }
} }
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
// (IdP email -> stored User.Email -> peer.Name).
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
const (
domain = "app.example.com"
accountID = "account1"
peerID = "peer1"
peerName = "peer-display-name"
userID = "user1"
)
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
tests := []struct {
name string
peerUserID string
storedUsers map[string]*types.User
storedErr error
noIdP bool
idpEmail string
idpHasData bool
idpErr error
expectEmail string
expectUserID string
expectIdPHit bool
}{
{
name: "idp email wins over stored email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp returns empty email",
peerUserID: userID,
storedUsers: storedUser,
idpEmail: "",
idpHasData: true,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp has no data",
peerUserID: userID,
storedUsers: storedUser,
idpHasData: false,
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when idp errors",
peerUserID: userID,
storedUsers: storedUser,
idpErr: errors.New("idp unreachable"),
expectEmail: "stored@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "stored email when no idp manager",
peerUserID: userID,
storedUsers: storedUser,
noIdP: true,
expectEmail: "stored@example.com",
expectUserID: userID,
},
{
name: "idp email when stored email is empty",
peerUserID: userID,
storedUsers: storedUserNoEmail,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "idp email when stored user missing keeps peer.UserID as principal",
peerUserID: userID,
storedUsers: map[string]*types.User{},
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: "idp@example.com",
expectUserID: userID,
expectIdPHit: true,
},
{
name: "unlinked peer uses peer name and never consults idp",
peerUserID: "",
storedUsers: storedUser,
idpEmail: "idp@example.com",
idpHasData: true,
expectEmail: peerName,
expectUserID: peerID,
expectIdPHit: false,
},
{
name: "linked peer with empty stored email and no idp falls back to peer name",
peerUserID: userID,
storedUsers: storedUserNoEmail,
noIdP: true,
expectEmail: peerName,
expectUserID: userID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
svc := &service.Service{Domain: domain, AccountID: accountID}
server := &ProxyServiceServer{
serviceManager: &mockReverseProxyManager{
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
},
peersManager: &mockTunnelPeersManager{
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
},
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
}
var idpMock *mockTunnelIdpManager
if !tt.noIdP {
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
server.idpManager = idpMock
}
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
Domain: domain,
TunnelIp: "100.64.0.1",
})
require.NoError(t, err)
require.NotNil(t, resp)
assert.True(t, resp.GetValid(), "expected access granted")
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
assert.Equal(t, tt.expectUserID, resp.GetUserId())
if idpMock != nil {
if tt.expectIdPHit {
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
require.Len(t, idpMock.gotMeta, 1)
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
} else {
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
}
}
})
}
}
func TestGetAccountProxyByDomain(t *testing.T) { func TestGetAccountProxyByDomain(t *testing.T) {
tests := []struct { tests := []struct {
name string name string


@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil) proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
proxyService.SetServiceManager(serviceManager) proxyService.SetServiceManager(serviceManager)
createTestProxies(t, ctx, testStore) createTestProxies(t, ctx, testStore)


@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err return nil, nil, err
} }
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil) proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{}) proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err


@@ -217,6 +217,7 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
usersManager, usersManager,
nil, nil,
nil, nil,
nil,
) )
proxyService.SetServiceManager(&testServiceManager{store: testStore}) proxyService.SetServiceManager(&testServiceManager{store: testStore})


@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err) t.Fatalf("Failed to create proxy manager: %v", err)
} }
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil { if err != nil {
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
if err != nil { if err != nil {
t.Fatalf("Failed to create proxy manager: %v", err) t.Fatalf("Failed to create proxy manager: %v", err)
} }
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil) proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am) domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter) serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
if err != nil { if err != nil {


@@ -982,8 +982,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var peer *nbpeer.Peer var peer *nbpeer.Peer
var updated, versionChanged, ipv6CapabilityChanged bool var updated, versionChanged, ipv6CapabilityChanged bool
var err error var err error
var postureChecks []*posture.Checks
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
@@ -1011,13 +1009,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return status.NewPeerLoginExpiredError() return status.NewPeerLoginExpiredError()
} }
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta) updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay) ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated { if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate() am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
@@ -1025,11 +1018,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
if err = transaction.SavePeer(ctx, accountID, peer); err != nil { if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
return err return err
} }
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
} }
return nil return nil
}) })
@@ -1037,6 +1025,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
} }
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil {
return nil, nil, nil, 0, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra) peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil { if err != nil {
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
@@ -1047,9 +1040,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
} }
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) { if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
changedPeerIDs := []string{peer.ID} changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(postureChecks) > 0) affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil { if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
} }
@@ -1124,7 +1117,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
} }
var peer *nbpeer.Peer var peer *nbpeer.Peer
var shouldStorePeer bool var shouldStorePeer, shouldUpdatePeers bool
var peerGroupIDs []string var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1151,14 +1144,10 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed { if changed {
shouldStorePeer = true shouldStorePeer = true
shouldUpdatePeers = true
} }
} }
peerGroupIDs, err = getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
if peer.SSHKey != login.SSHKey { if peer.SSHKey != login.SSHKey {
peer.SSHKey = login.SSHKey peer.SSHKey = login.SSHKey
shouldStorePeer = true shouldStorePeer = true
@@ -1180,7 +1169,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
 return nil, nil, nil, false, err
 }
-isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
+// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
+peer.UpdateMetaIfNew(ctx, login.Meta)
+peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
+if err != nil {
+return nil, nil, nil, false, err
+}
+isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
 if err != nil {
 return nil, nil, nil, false, err
 }
@@ -1190,7 +1187,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err return nil, nil, nil, false, err
} }
if isStatusChanged || shouldStorePeer { if shouldUpdatePeers {
changedPeerIDs := []string{peer.ID} changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil { if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1286,12 +1283,22 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
 return network, nil, false, nil
 }
-postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
+policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
 if err != nil {
 return nil, nil, false, err
 }
-enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
+peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
+if err != nil {
+return nil, nil, false, err
+}
+postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
+if err != nil {
+return nil, nil, false, err
+}
+enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
 if err != nil {
 return nil, nil, false, err
 }
@@ -1299,32 +1306,16 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
 return network, postureChecks, enableSSH, nil
 }
-func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
-policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
-if err != nil {
-return false, err
-}
-peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
-if err != nil {
-return false, err
-}
-peerGroupIDs := make(map[string]struct{}, len(peerGroups))
-for _, g := range peerGroups {
-peerGroupIDs[g.ID] = struct{}{}
-}
-return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
+func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
+groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
+for _, peerID := range peerGroupIDs {
+groupIDsMap[peerID] = struct{}{}
+}
+return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
 }
 // getPeerPostureChecks returns the posture checks for the peer.
-func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
-policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
-if err != nil {
-return nil, err
-}
+func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
 if len(policies) == 0 {
 return nil, nil
 }
@@ -1336,11 +1327,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
 continue
 }
-postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
-if err != nil {
-return nil, err
-}
+postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
 peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
 }
@@ -1353,29 +1340,19 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
 }
 // processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
-func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
+func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
 for _, rule := range policy.Rules {
 if !rule.Enabled {
 continue
 }
-sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
-if err != nil {
-return nil, err
-}
 for _, sourceGroup := range rule.Sources {
-group, ok := sourceGroups[sourceGroup]
-if !ok {
-return nil, fmt.Errorf("failed to check peer in policy source group")
-}
-if slices.Contains(group.Peers, peerID) {
-return policy.SourcePostureChecks, nil
+if slices.Contains(peerGroupIDs, sourceGroup) {
+return policy.SourcePostureChecks
 }
 }
 }
-return nil, nil
+return nil
 }
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
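
The reworked processPeerPostureChecks above no longer loads each source group to look for the peer. It compares the rule's source group IDs against the peer's precomputed group IDs instead. A minimal, self-contained sketch of that membership test follows. The rule and policy types and the postureChecksFor name are hypothetical, trimmed-down stand-ins for the real policy types, which carry more fields.

```go
package main

import (
	"fmt"
	"slices"
)

// rule is a hypothetical stand-in for a policy rule.
type rule struct {
	Enabled bool
	Sources []string // source group IDs
}

// policy is a hypothetical stand-in for the real policy type.
type policy struct {
	Rules               []rule
	SourcePostureChecks []string // posture check IDs attached to the policy
}

// postureChecksFor returns the policy's posture check IDs when any enabled
// rule lists one of the peer's groups as a source, and nil otherwise.
// No store access is needed because the peer's group IDs are passed in.
func postureChecksFor(p policy, peerGroupIDs []string) []string {
	for _, r := range p.Rules {
		if !r.Enabled {
			continue
		}
		for _, src := range r.Sources {
			if slices.Contains(peerGroupIDs, src) {
				return p.SourcePostureChecks
			}
		}
	}
	return nil
}

func main() {
	p := policy{
		Rules:               []rule{{Enabled: true, Sources: []string{"devs", "ops"}}},
		SourcePostureChecks: []string{"os-version"},
	}
	fmt.Println(postureChecksFor(p, []string{"ops"}))
	fmt.Println(postureChecksFor(p, []string{"guests"}))
}
```

Because the function is now pure, it cannot fail, which is why the error return disappears from the signature in the diff.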


@@ -1,12 +1,16 @@
 package peer
 import (
+"context"
+"fmt"
 "net"
 "net/netip"
 "slices"
-"sort"
+"strings"
 "time"
+log "github.com/sirupsen/logrus"
 "github.com/netbirdio/netbird/management/server/util"
 "github.com/netbirdio/netbird/shared/management/http/api"
 )
@@ -162,49 +166,7 @@ type PeerSystemMeta struct { //nolint:revive
 }
 func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
-sort.Slice(p.NetworkAddresses, func(i, j int) bool {
-return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
-})
-sort.Slice(other.NetworkAddresses, func(i, j int) bool {
-return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
-})
-equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
-return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
-})
-if !equalNetworkAddresses {
-return false
-}
-sort.Slice(p.Files, func(i, j int) bool {
-return p.Files[i].Path < p.Files[j].Path
-})
-sort.Slice(other.Files, func(i, j int) bool {
-return other.Files[i].Path < other.Files[j].Path
-})
-equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
-return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
-})
-if !equalFiles {
-return false
-}
-return p.Hostname == other.Hostname &&
-p.GoOS == other.GoOS &&
-p.Kernel == other.Kernel &&
-p.KernelVersion == other.KernelVersion &&
-p.Core == other.Core &&
-p.Platform == other.Platform &&
-p.OS == other.OS &&
-p.OSVersion == other.OSVersion &&
-p.WtVersion == other.WtVersion &&
-p.UIVersion == other.UIVersion &&
-p.SystemSerialNumber == other.SystemSerialNumber &&
-p.SystemProductName == other.SystemProductName &&
-p.SystemManufacturer == other.SystemManufacturer &&
-p.Environment.Cloud == other.Environment.Cloud &&
-p.Environment.Platform == other.Environment.Platform &&
-p.Flags.isEqual(other.Flags) &&
-capabilitiesEqual(p.Capabilities, other.Capabilities)
+return len(metaDiff(p, other)) == 0
 }
 func (p PeerSystemMeta) isEmpty() bool {
@@ -296,7 +258,7 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided // UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise // returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) { func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() { if meta.isEmpty() {
return updated, versionChanged return updated, versionChanged
} }
@@ -308,14 +270,121 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
meta.UIVersion = p.Meta.UIVersion meta.UIVersion = p.Meta.UIVersion
} }
if p.Meta.isEqual(meta) { oldVersion := p.Meta.WtVersion
return updated, versionChanged
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
p.Meta = meta
updated = true
} }
p.Meta = meta
updated = true versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged return updated, versionChanged
} }
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
var diff []string
add := func(field string, oldVal, newVal any) {
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return diff
}
// sameMultiset reports whether two slices contain the same elements with the
// same multiplicity, ignoring order. The element type is the comparison key, so
// every field participates in equality.
func sameMultiset[T comparable](a, b []T) bool {
if len(a) != len(b) {
return false
}
counts := make(map[T]int, len(a))
for _, v := range a {
counts[v]++
}
for _, v := range b {
counts[v]--
if counts[v] == 0 {
delete(counts, v)
}
}
return len(counts) == 0
}
// GetLastLogin returns the last login time of the peer. // GetLastLogin returns the last login time of the peer.
func (p *Peer) GetLastLogin() time.Time { func (p *Peer) GetLastLogin() time.Time {
if p.LastLogin != nil { if p.LastLogin != nil {
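
Review note: the order-insensitive comparison that replaces the old sort-then-compare logic can be tried in isolation. The sketch below copies `sameMultiset` from the hunk above; the `addr` type is a hypothetical stand-in for `NetworkAddress` (its real field set is not shown in this diff), so treat it as an illustration rather than the package's actual types.

```go
package main

import "fmt"

// sameMultiset is copied from the diff: it reports whether two slices hold
// the same elements with the same multiplicity, ignoring order.
func sameMultiset[T comparable](a, b []T) bool {
	if len(a) != len(b) {
		return false
	}
	counts := make(map[T]int, len(a))
	for _, v := range a {
		counts[v]++
	}
	for _, v := range b {
		counts[v]--
		if counts[v] == 0 {
			delete(counts, v)
		}
	}
	return len(counts) == 0
}

// addr is a hypothetical comparable struct used only for this example.
type addr struct {
	Mac string
	IP  string
}

func main() {
	a := []addr{{"aa", "10.0.0.1"}, {"bb", "10.0.0.2"}}
	b := []addr{{"bb", "10.0.0.2"}, {"aa", "10.0.0.1"}} // same elements, reordered
	c := []addr{{"aa", "10.0.0.1"}, {"aa", "10.0.0.1"}} // same length, different multiplicity

	fmt.Println(sameMultiset(a, b)) // true
	fmt.Println(sameMultiset(a, c)) // false
}
```

Because the whole element is the map key, every field of the element takes part in the comparison, and neither input slice is sorted or otherwise modified.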


@@ -0,0 +1,113 @@
package peer
import (
"net/netip"
"reflect"
"testing"
"github.com/stretchr/testify/require"
)
// metaDiffExtraEntries accounts for PeerSystemMeta fields that metaDiff does not
// map 1:1 to a single diff entry. Today the only such field is Environment, which
// is exploded into two checks (Cloud, Platform) and therefore yields one extra
// entry beyond its single struct field. If you teach metaDiff to explode another
// field into N entries, bump this by N-1; if you collapse a field, lower it.
const metaDiffExtraEntries = 1
// TestMetaDiff_CoversAllFields fully populates a PeerSystemMeta with non-zero
// values and diffs it against the zero value, then asserts metaDiff emits exactly
// one entry per exported field (plus metaDiffExtraEntries for fields it explodes).
//
// The expected count is derived from the struct via reflection, so adding a field
// to PeerSystemMeta raises the expectation automatically — but the actual diff
// only grows if metaDiff was taught to compare the new field. A mismatch means
// someone changed the struct without updating metaDiff (or this test's
// extra-entry accounting), which is exactly what we want to catch.
func TestMetaDiff_CoversAllFields(t *testing.T) {
var full PeerSystemMeta
exported := populateAll(t, reflect.ValueOf(&full).Elem())
require.NotZero(t, exported, "expected PeerSystemMeta to expose fields")
diff := metaDiff(PeerSystemMeta{}, full)
require.Len(t, diff, exported+metaDiffExtraEntries,
"metaDiff entry count no longer matches PeerSystemMeta's fields: a field was "+
"likely added or removed without updating metaDiff (or metaDiffExtraEntries). "+
"diff was: %v", diff)
require.False(t, full.isEqual(PeerSystemMeta{}),
"isEqual must report a fully-populated meta as different from the zero value")
}
// TestFlags_isEqualChecksEveryField guards the one field that the count-based
// TestMetaDiff_CoversAllFields cannot: metaDiff collapses all of Flags into a
// single "flags" diff entry, so a new Flags field that Flags.isEqual forgets to
// compare would not change the diff count. This flips each Flags field on its own
// and asserts Flags.isEqual notices, so adding a Flags field without comparing it
// fails here.
func TestFlags_isEqualChecksEveryField(t *testing.T) {
typ := reflect.TypeOf(Flags{})
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
require.Equal(t, reflect.Bool, f.Type.Kind(),
"Flags.%s is not a bool; extend this test to set it non-zero", f.Name)
var a, b Flags
reflect.ValueOf(&b).Elem().Field(i).SetBool(true)
require.False(t, a.isEqual(b), "Flags.isEqual ignores field %s", f.Name)
}
}
// populateAll sets every exported field of the struct to a deterministic non-zero
// value, recursing into nested structs and the element type of struct slices so
// that each leaf differs from zero. It returns the number of exported fields on
// the top-level struct. netip.Prefix is treated as an opaque leaf (it has no
// settable exported fields and is comparable with ==).
func populateAll(t *testing.T, v reflect.Value) int {
t.Helper()
typ := v.Type()
exported := 0
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
if f.PkgPath != "" { // unexported
continue
}
exported++
setNonZero(t, v.Field(i))
}
return exported
}
// setNonZero assigns a deterministic non-zero value to a field based on its kind,
// recursing into nested structs and populating one element of slice fields.
func setNonZero(t *testing.T, field reflect.Value) {
t.Helper()
if field.Type() == reflect.TypeOf(netip.Prefix{}) {
field.Set(reflect.ValueOf(netip.MustParsePrefix("10.0.0.0/24")))
return
}
switch field.Kind() {
case reflect.String:
field.SetString("non-zero")
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Float32, reflect.Float64:
field.SetFloat(7)
case reflect.Struct:
populateAll(t, field)
case reflect.Slice:
s := reflect.MakeSlice(field.Type(), 1, 1)
setNonZero(t, s.Index(0))
field.Set(s)
default:
t.Fatalf("unhandled field kind %s; extend setNonZero", field.Kind())
}
}


@@ -1847,12 +1847,17 @@ func (am *DefaultAccountManager) DeleteUserInvite(ctx context.Context, accountID
const minPasswordLength = 8 const minPasswordLength = 8
// validatePassword checks password strength requirements: // validatePassword checks password strength requirements.
func validatePassword(password string) error {
return ValidatePassword(password)
}
// ValidatePassword checks password strength requirements:
// - Minimum 8 characters // - Minimum 8 characters
// - At least 1 digit // - At least 1 digit
// - At least 1 uppercase letter // - At least 1 uppercase letter
// - At least 1 special character // - At least 1 special character
func validatePassword(password string) error { func ValidatePassword(password string) error {
if len(password) < minPasswordLength { if len(password) < minPasswordLength {
return errors.New("password must be at least 8 characters long") return errors.New("password must be at least 8 characters long")
} }


@@ -125,6 +125,7 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,
nil,
realProxyManager, realProxyManager,
nil, nil,
) )


@@ -140,6 +140,7 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
oidcConfig, oidcConfig,
nil, nil,
usersManager, usersManager,
nil,
proxyManager, proxyManager,
nil, nil,
) )


@@ -21,7 +21,8 @@ AWK_FIRST_FIELD='{print $1}'
fetch_all_tags() { fetch_all_tags() {
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \ sed 's/.*\/v//' | \
sort -u -V sort -u -V
return 0 return 0
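
Review note: a plausible reading of this change is that the widened pattern now captures any suffix after the version (for example `-rc1`), so that the following `grep -iv 'rc'` can drop release candidates instead of letting them be truncated to a bare `X.Y.Z`. The Go sketch below mimics the `grep`/`grep -iv`/`sed` stages on a hypothetical HTML snippet. `stableTags` and the sample page are assumptions for illustration, and the version sort (`sort -u -V`) is reduced to de-duplication.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// tagRe is the extended pattern from the new side of the diff.
var tagRe = regexp.MustCompile(`/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?`)

// stableTags is a hypothetical helper: it extracts release tags from a page,
// skips anything containing "rc" (case-insensitive), strips the path prefix
// up to and including "/v", and removes duplicates. Ordering is not changed.
func stableTags(page string) []string {
	seen := map[string]bool{}
	var out []string
	for _, m := range tagRe.FindAllString(page, -1) {
		if strings.Contains(strings.ToLower(m), "rc") {
			continue // equivalent of grep -iv 'rc'
		}
		v := m[strings.LastIndex(m, "/v")+2:] // equivalent of sed 's/.*\/v//'
		if !seen[v] {
			seen[v] = true
			out = append(out, v)
		}
	}
	return out
}

func main() {
	// Made-up page content, not real GitHub output.
	page := `<a href="/o/r/releases/tag/v1.2.3">a</a>
<a href="/o/r/releases/tag/v1.3.0-rc1">b</a>
<a href="/o/r/releases/tag/v1.2.3">c</a>`
	fmt.Println(stableTags(page)) // [1.2.3]
}
```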


@@ -32,7 +32,8 @@ fetch_current_ports_version() {
fetch_all_tags() { fetch_all_tags() {
# Fetch tags from GitHub tags page (no rate limiting, no auth needed) # Fetch tags from GitHub tags page (no rate limiting, no auth needed)
curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \ curl -sL "https://github.com/${GITHUB_REPO}/tags" 2>/dev/null | \
grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+' | \ grep -oE '/releases/tag/v[0-9]+\.[0-9]+\.[0-9]+([^"]+)?' | \
grep -iv 'rc' | \
sed 's/.*\/v//' | \ sed 's/.*\/v//' | \
sort -u -V sort -u -V
return 0 return 0


@@ -33,7 +33,7 @@ type Client interface {
Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error Receive(ctx context.Context, msgHandler func(msg *proto.Message) error) error
Ready() bool Ready() bool
IsHealthy() bool IsHealthy() bool
WaitStreamConnected() WaitStreamConnected(context.Context)
SendToStream(msg *proto.EncryptedMessage) error SendToStream(msg *proto.EncryptedMessage) error
Send(msg *proto.Message) error Send(msg *proto.Message) error
SetOnReconnectedListener(func()) SetOnReconnectedListener(func())


@@ -65,7 +65,10 @@ var _ = Describe("GrpcClient", func() {
return return
} }
}() }()
clientA.WaitStreamConnected() ctxA, cancelA := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelA()
clientA.WaitStreamConnected(ctxA)
Expect(clientA.StreamConnected()).To(BeTrue())
// connect PeerB to Signal // connect PeerB to Signal
keyB, _ := wgtypes.GenerateKey() keyB, _ := wgtypes.GenerateKey()
@@ -91,7 +94,10 @@ var _ = Describe("GrpcClient", func() {
} }
}() }()
clientB.WaitStreamConnected() ctxB, cancelB := context.WithTimeout(context.Background(), 5*time.Second)
defer cancelB()
clientB.WaitStreamConnected(ctxB)
Expect(clientB.StreamConnected()).To(BeTrue())
// PeerA initiates ping-pong // PeerA initiates ping-pong
err := clientA.Send(&sigProto.Message{ err := clientA.Send(&sigProto.Message{
@@ -129,8 +135,10 @@ var _ = Describe("GrpcClient", func() {
return return
} }
}() }()
client.WaitStreamConnected() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
Expect(client).NotTo(BeNil()) defer cancel()
client.WaitStreamConnected(ctx)
Expect(client.StreamConnected()).To(BeTrue())
}) })
}) })


@@ -246,15 +246,6 @@ func (c *GrpcClient) notifyStreamConnected() {
} }
} }
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
}
func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) { func (c *GrpcClient) connect(ctx context.Context, key string) (proto.SignalExchange_ConnectStreamClient, error) {
c.stream = nil c.stream = nil
@@ -310,14 +301,24 @@ func (c *GrpcClient) IsHealthy() bool {
} }
// WaitStreamConnected waits until the client is connected to the Signal stream // WaitStreamConnected waits until the client is connected to the Signal stream
func (c *GrpcClient) WaitStreamConnected() { func (c *GrpcClient) WaitStreamConnected(ctx context.Context) {
// Check the status and obtain the wait channel atomically: otherwise
// notifyStreamConnected could flip the status and close/clear the channel
// between the check and the channel creation, leaving us waiting forever on
// a stale channel.
c.mux.Lock()
if c.status == StreamConnected { if c.status == StreamConnected {
c.mux.Unlock()
return return
} }
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
ch := c.connectedCh
c.mux.Unlock()
ch := c.getStreamStatusChan()
select { select {
case <-ctx.Done():
case <-c.ctx.Done(): case <-c.ctx.Done():
case <-ch: case <-ch:
} }


@@ -55,7 +55,7 @@ func (sm *MockClient) Ready() bool {
return sm.ReadyFunc() return sm.ReadyFunc()
} }
func (sm *MockClient) WaitStreamConnected() { func (sm *MockClient) WaitStreamConnected(context.Context) {
if sm.WaitStreamConnectedFunc == nil { if sm.WaitStreamConnectedFunc == nil {
return return
} }


@@ -65,7 +65,7 @@ func TestReceiveProbeRoundTrips(t *testing.T) {
streamReady := make(chan struct{}) streamReady := make(chan struct{})
go func() { go func() {
client.WaitStreamConnected() client.WaitStreamConnected(ctx)
close(streamReady) close(streamReady)
}() }()
select { select {


@@ -26,6 +26,10 @@ type Peer struct {
// a gRpc connection stream to the Peer // a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer Stream proto.SignalExchange_ConnectStreamServer
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
// the same ServerStream, and a peer can be the target of many senders at
// once.
sendMu sync.Mutex
// registration time // registration time
RegisteredAt time.Time RegisteredAt time.Time
@@ -33,6 +37,13 @@ type Peer struct {
Cancel context.CancelFunc Cancel context.CancelFunc
} }
// Send writes a message to the peer's stream, serializing concurrent senders.
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.Stream.Send(msg)
}
// NewPeer creates a new instance of a connected Peer // NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer { func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
return &Peer{ return &Peer{
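
Review note: the effect of `sendMu` can be demonstrated without gRPC. In the sketch below, `sender`, `countingStream`, `peer`, and `maxConcurrentSends` are hypothetical names. The stream records the highest number of overlapping `Send` calls, in the same spirit as the `concurrencyCheckStream` test helper added in this change set, and the mutex-guarded wrapper keeps that number at one.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// sender is a hypothetical stand-in for the gRPC server stream.
type sender interface{ Send(msg string) error }

// countingStream records the maximum number of Send calls in flight at once.
type countingStream struct {
	inflight atomic.Int32
	maxSeen  atomic.Int32
}

func (s *countingStream) Send(string) error {
	n := s.inflight.Add(1)
	for {
		old := s.maxSeen.Load()
		if n <= old || s.maxSeen.CompareAndSwap(old, n) {
			break
		}
	}
	time.Sleep(time.Millisecond) // widen the window for overlapping callers
	s.inflight.Add(-1)
	return nil
}

// peer serializes writes to its stream with a mutex, like Peer.Send above.
type peer struct {
	stream sender
	sendMu sync.Mutex
}

func (p *peer) Send(msg string) error {
	p.sendMu.Lock()
	defer p.sendMu.Unlock()
	return p.stream.Send(msg)
}

// maxConcurrentSends fires n concurrent senders at one peer and returns the
// highest overlap the stream observed.
func maxConcurrentSends(n int) int32 {
	stream := &countingStream{}
	p := &peer{stream: stream}
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = p.Send("hello")
		}()
	}
	wg.Wait()
	return stream.maxSeen.Load()
}

func main() {
	fmt.Println(maxConcurrentSends(20)) // 1
}
```

Calling `p.stream.Send` directly from the goroutines instead of `p.Send` would typically report a value above one, which is the situation the mutex is meant to rule out.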


@@ -0,0 +1,67 @@
package server
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/peer"
)
// concurrencyCheckStream records the maximum number of Send calls in flight at
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
// server must never have more than one in flight per peer.
type concurrencyCheckStream struct {
proto.SignalExchange_ConnectStreamServer
ctx context.Context
inflight atomic.Int32
maxSeen atomic.Int32
}
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
n := s.inflight.Add(1)
for {
old := s.maxSeen.Load()
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
break
}
}
// Widen the window so overlapping callers are reliably observed.
time.Sleep(time.Millisecond)
s.inflight.Add(-1)
return nil
}
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
// same peer never call Stream.Send concurrently, which would violate the gRPC
// ServerStream contract.
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
s, err := NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
const peerID = "peerX"
stream := &concurrencyCheckStream{ctx: context.Background()}
_, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
}()
}
wg.Wait()
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
}


@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
sendResultChan := make(chan error, 1) sendResultChan := make(chan error, 1)
go func() { go func() {
select { select {
case sendResultChan <- dstPeer.Stream.Send(msg): case sendResultChan <- dstPeer.Send(msg):
return return
case <-dstPeer.Stream.Context().Done(): case <-dstPeer.Stream.Context().Done():
return return